In [147]:
import os
# Hide every CUDA device from this process; presumably this forces JAX onto the
# CPU backend. Both variables are set before the JAX-based imports below.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Request 64-bit floats from JAX (the normalization statistics printed further
# down are indeed float64).
os.environ["JAX_ENABLE_X64"] = "True"

import lagrangebench
import haiku as hk
import numpy as np
import matplotlib.pyplot as plt
import pickle
import matplotlib.animation as animation

In [148]:
# Training samples carry 5 extra frames: this equals the largest pushforward
# unroll configured further down. Validation samples carry 20 extra frames:
# this equals the 20-step rollout requested from the Trainer.
ht2d_train = lagrangebench.data.HT2D("train", extra_seq_length=5)
ht2d_valid = lagrangebench.data.HT2D("valid", extra_seq_length=20)

# Summary of the dataset; the first entry of a sample is its position array.
print(f"This is a {ht2d_train.metadata['dim']}D dataset called {ht2d_train.metadata['case']}.")
print(f"Train snapshot have shape {ht2d_train[0][0].shape} (n_nodes, seq_len, xy pos).")
print(f"Val snapshot have shape {ht2d_valid[0][0].shape} (n_nodes, rollout, xy pos).")
print()

This is a 2D dataset called HT.
Train snapshot have shape (950, 12, 2) (n_nodes, seq_len, xy pos).
Val snapshot have shape (950, 26, 2) (n_nodes, rollout, xy pos).



In [143]:
# Second entry of a sample = per-particle type tags; show the first 864 of
# training sample 439.
ht2d_train[439][1][:864]

array([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [149]:
def gns(x):
    """Forward pass of a small GNS model, meant to be wrapped by a Haiku transform.

    The module is built on every call (as Haiku expects inside a transformed
    function) and immediately applied to ``x``. The particle dimension is read
    from ``ht2d_train.metadata``.

    Args:
        x: Model input, forwarded unchanged to ``lagrangebench.GNS``.

    Returns:
        Whatever the ``lagrangebench.GNS`` module returns for ``x``.
    """
    model = lagrangebench.GNS(
        particle_dimension=ht2d_train.metadata["dim"],
        latent_size=16,
        blocks_per_step=2,
        num_mp_steps=4,
        particle_type_embedding_size=8,
    )
    return model(x)

In [4]:
# Turn the plain function into a Haiku init/apply pair that tracks state, then
# drop the rng argument from apply.
# NOTE: this rebinds the name `gns`, so the cell is not idempotent -- re-run the
# cell defining `def gns(x)` before re-running this one.
gns = hk.transform_with_state(gns)
gns = hk.without_apply_rng(gns)

In [5]:
# Standard deviation of the random-walk noise; passed to both case_builder and
# the Trainer below.
noise_std = 3e-4

# Pushforward schedule with three stages, described column-wise by the lists.
# NOTE(review): stages two and three unlock at steps 500 and 700, but the
# training run in this notebook uses step_max=100 -- presumably only the first
# stage (0 unrolls) is ever active here; confirm this is intended.
pf_config = lagrangebench.PushforwardConfig(
    steps=[-1, 500, 700],  # training step at which the respective stage unlocks
    unrolls=[0, 2, 5],  # number of unroll steps per stage
    probs=[7, 2, 1],  # relative probability of picking the respective stage
)

In [6]:
"""Case setup functions."""

from typing import Callable, Dict, Optional, Tuple, Union

import jax.numpy as jnp
from jax import Array, jit, lax, vmap
from jax_md import space
from jax_md.dataclasses import dataclass, static_field
from jax_md.partition import NeighborList, NeighborListFormat

from lagrangebench.data.utils import get_dataset_stats
from lagrangebench.defaults import defaults
from lagrangebench.train.strats import add_gns_noise

from lagrangebench.case_setup.features import FeatureDict, TargetDict, physical_feature_builder
from lagrangebench.case_setup.partition import neighbor_list

# Output of train-mode preprocessing: (PRNG key, features, targets, neighbors).
TrainCaseOut = Tuple[Array, FeatureDict, TargetDict, NeighborList]
# Output of eval-mode preprocessing: (features, neighbors).
EvalCaseOut = Tuple[FeatureDict, NeighborList]
# One dataset sample: (positions, particle types).
SampleIn = Tuple[jnp.ndarray, jnp.ndarray]

# (key, sample, noise_std, unroll_steps) -> TrainCaseOut
AllocateFn = Callable[[Array, SampleIn, float, int], TrainCaseOut]
# (sample) -> EvalCaseOut
AllocateEvalFn = Callable[[SampleIn], EvalCaseOut]

# (key, sample, noise_std, neighbors, unroll_steps) -> TrainCaseOut
PreprocessFn = Callable[[Array, SampleIn, float, NeighborList, int], TrainCaseOut]
# (sample, neighbors) -> EvalCaseOut
PreprocessEvalFn = Callable[[SampleIn, NeighborList], EvalCaseOut]

# (normalized_in, position_sequence) -> next positions.
# The first argument is a dict keyed by "pos", "vel" or "acc" (the integrator
# indexes it by those keys), not a bare array.
IntegrateFn = Callable[[Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]

In [7]:
@dataclass
class CaseSetupFn:
    """Dataclass that contains all functions required to set up the case and simulate.

    All fields are declared as static fields. Field order matters: the builder
    constructs this class positionally.

    Attributes:
        allocate: AllocateFn, runs the preprocessing without having a NeighborList as
            input (a new neighbor list is allocated instead).
        preprocess: PreprocessFn, takes positions from the dataloader, computes
            velocities, adds random-walk noise if needed, then updates the neighbor
            list, and returns the inputs to the neural network as well as the targets.
        allocate_eval: AllocateEvalFn, same as allocate, but without noise addition
            and without targets.
        preprocess_eval: PreprocessEvalFn, same as allocate_eval, but jit-able
            (updates an existing neighbor list instead of allocating one).
        integrate: IntegrateFn, semi-implicit Euler integration step respecting
            all boundary conditions.
        displacement: space.DisplacementFn, displacement function aware of boundary
            conditions (periodic or non-periodic).
        normalization_stats: Dict, normalization statistics (mean and std) for input
            velocities and output accelerations.
    """

    allocate: AllocateFn = static_field()
    preprocess: PreprocessFn = static_field()
    allocate_eval: AllocateEvalFn = static_field()
    preprocess_eval: PreprocessEvalFn = static_field()
    integrate: IntegrateFn = static_field()
    displacement: space.DisplacementFn = static_field()
    normalization_stats: Dict = static_field()

In [8]:
def case_builder(
    box: Tuple[float, float, float],
    metadata: Dict,
    input_seq_length: int,
    isotropic_norm: bool = defaults.isotropic_norm,
    noise_std: float = defaults.noise_std,
    external_force_fn: Optional[Callable] = None,
    magnitude_features: bool = defaults.magnitude_features,
    neighbor_list_backend: str = defaults.neighbor_list_backend,
    neighbor_list_multiplier: float = defaults.neighbor_list_multiplier,
    dtype: jnp.dtype = defaults.dtype,
):
    """Set up a CaseSetupFn that contains every required function besides the model.

    Inspired by the `partition.neighbor_list` function in JAX-MD.

    The core functions are:
        * allocate, allocate memory for the neighbors list.
        * preprocess, update the neighbors list.
        * integrate, semi-implicit Euler respecting periodic boundary conditions.

    Args:
        box: Box sizes of the system, one entry per spatial dimension. It is
            converted with `jnp.array(box)` before use.
        metadata: Dataset metadata dictionary. This function itself reads the keys
            "periodic_boundary_conditions", "default_connectivity_radius",
            "num_particles_max" and "bounds"; the dictionary is also handed to
            `get_dataset_stats`.
        input_seq_length: Length of the input sequence.
        isotropic_norm: Whether to normalize dimensions equally.
        noise_std: Noise standard deviation. Here it is only forwarded to
            `get_dataset_stats`; the noise actually added during preprocessing
            uses the `noise_std` argument of `allocate` / `preprocess`.
        external_force_fn: External force function.
        magnitude_features: Whether to add velocity magnitudes in the features.
        neighbor_list_backend: Backend of the neighbor list.
        neighbor_list_multiplier: Capacity multiplier of the neighbor list.
        dtype: Data type the input positions are cast to.

    Returns:
        A CaseSetupFn holding, in order: allocate, preprocess, allocate_eval,
        preprocess_eval, integrate, the displacement function and the
        normalization statistics.
    """
    normalization_stats = get_dataset_stats(metadata, isotropic_norm, noise_std)

    # apply PBC in all directions or not at all: a single periodic axis is
    # enough to select the fully periodic space.
    if jnp.array(metadata["periodic_boundary_conditions"]).any():
        displacement_fn, shift_fn = space.periodic(side=jnp.array(box))
        # displacement_fn(Ra, Rb, **kwargs): displacement between two particles;
        #   Ra and Rb have shape [spatial_dim], and so does the result.
        # shift_fn(R, dR, **kwargs): moves points at position R by an amount dR.
    else:
        displacement_fn, shift_fn = space.free()

    # Version of displacement_fn mapped over the leading axis of both arguments.
    displacement_fn_set = vmap(displacement_fn, in_axes=(0, 0))

    neighbor_fn = neighbor_list(
        displacement_fn,
        jnp.array(box),
        backend=neighbor_list_backend,
        r_cutoff=metadata["default_connectivity_radius"],
        capacity_multiplier=neighbor_list_multiplier,
        mask_self=False,
        format=NeighborListFormat.Sparse,
        num_particles_max=metadata["num_particles_max"],
        pbc=metadata["periodic_boundary_conditions"],
    )

    feature_transform = physical_feature_builder(
        bounds=metadata["bounds"],
        normalization_stats=normalization_stats,
        connectivity_radius=metadata["default_connectivity_radius"],
        displacement_fn=displacement_fn,
        pbc=metadata["periodic_boundary_conditions"],
        magnitude_features=magnitude_features,
        external_force_fn=external_force_fn,
    )

    def _compute_target(pos_input: jnp.ndarray) -> TargetDict:
        """Build normalized training targets from three consecutive positions.

        Args:
            pos_input: Positions whose axis 1 holds three consecutive frames
                (indices 0, 1 and 2 are read).

        Returns:
            Dict with "acc" (normalized acceleration), "vel" (normalized velocity
            between the last two frames) and "pos" (the last frame, unnormalized).
        """
        # displacement(r1, r2) = r1-r2  # without PBC

        # Finite differences with an implicit time step of 1.
        current_velocity = displacement_fn_set(pos_input[:, 1], pos_input[:, 0])
        next_velocity = displacement_fn_set(pos_input[:, 2], pos_input[:, 1])
        current_acceleration = next_velocity - current_velocity

        acc_stats = normalization_stats["acceleration"]
        normalized_acceleration = (
            current_acceleration - acc_stats["mean"]
        ) / acc_stats["std"]

        vel_stats = normalization_stats["velocity"]
        normalized_velocity = (next_velocity - vel_stats["mean"]) / vel_stats["std"]
        return {
            "acc": normalized_acceleration,
            "vel": normalized_velocity,
            "pos": pos_input[:, -1],
        }

    def _preprocess(
        sample: Tuple[jnp.ndarray, jnp.ndarray],
        neighbors: Optional[NeighborList] = None,
        is_allocate: bool = False,
        mode: str = "train",
        **kwargs,  # key, noise_std, unroll_steps
    ) -> Union[TrainCaseOut, EvalCaseOut]:
        """Shared implementation behind the four allocate/preprocess entry points.

        Args:
            sample: (positions, particle types) pair.
            neighbors: Existing neighbor list to update; unused when
                `is_allocate` is True.
            is_allocate: If True, allocate a fresh neighbor list instead of
                updating `neighbors`.
            mode: "train" (adds noise, returns targets) or "eval".
            **kwargs: In "train" mode the keys "key", "noise_std" and
                "unroll_steps" are required.

        Returns:
            (key, features, target_dict, neighbors) in "train" mode and
            (features, neighbors) in "eval" mode.
            NOTE(review): any other `mode` falls through and returns None.
        """
        pos_input = jnp.asarray(sample[0], dtype=dtype)
        particle_type = jnp.asarray(sample[1])

        if mode == "train":
            # This noise_std is the per-call value, not the builder argument.
            key, noise_std = kwargs["key"], kwargs["noise_std"]
            unroll_steps = kwargs["unroll_steps"]
            # Noise is only added when more than one frame is available.
            if pos_input.shape[1] > 1:
                key, pos_input = add_gns_noise(
                    key, pos_input, particle_type, input_seq_length, noise_std, shift_fn
                )

        # allocate the neighbor list
        # The graph is built on the last frame of the input window.
        most_recent_position = pos_input[:, input_seq_length - 1]
        # presumably particle type -1 marks padding entries -- TODO confirm
        num_particles = (particle_type != -1).sum()
        if is_allocate:
            neighbors = neighbor_fn.allocate(
                most_recent_position, num_particles=num_particles
            )
        else:
            neighbors = neighbors.update(
                most_recent_position, num_particles=num_particles
            )

        # selected features
        features = feature_transform(pos_input[:, :input_seq_length], neighbors)

        if mode == "train":
            # compute target acceleration. Inverse of postprocessing step.
            # the "-2" is needed because we need the most recent position and one before
            # The 3-frame window is (previous, most recent, next) position,
            # shifted forward by unroll_steps frames.
            # NOTE(review): per the JAX docs lax.dynamic_slice clamps start
            # indices to keep the slice in bounds, so an unroll_steps larger
            # than the extra frames in the sample would presumably reuse the
            # last frames silently rather than raise -- confirm.
            slice_begin = (0, input_seq_length - 2 + unroll_steps, 0)
            slice_size = (pos_input.shape[0], 3, pos_input.shape[2])

            target_dict = _compute_target(
                lax.dynamic_slice(pos_input, slice_begin, slice_size)
            )
            return key, features, target_dict, neighbors
        if mode == "eval":
            return features, neighbors

    def allocate_fn(key, sample, noise_std=0.0, unroll_steps=0):
        """Train-mode preprocessing that allocates a new neighbor list (not jitted)."""
        return _preprocess(
            sample,
            key=key,
            noise_std=noise_std,
            unroll_steps=unroll_steps,
            is_allocate=True,
        )

    @jit
    def preprocess_fn(key, sample, noise_std, neighbors, unroll_steps=0):
        """Train-mode preprocessing that updates an existing neighbor list (jitted)."""
        return _preprocess(
            sample, neighbors, key=key, noise_std=noise_std, unroll_steps=unroll_steps
        )

    def allocate_eval_fn(sample):
        """Eval-mode preprocessing that allocates a new neighbor list (not jitted)."""
        return _preprocess(sample, is_allocate=True, mode="eval")

    @jit
    def preprocess_eval_fn(sample, neighbors):
        """Eval-mode preprocessing that updates an existing neighbor list (jitted)."""
        return _preprocess(sample, neighbors, mode="eval")

    @jit
    def integrate_fn(normalized_in, position_sequence):
        """Euler integrator to get position shift.

        Args:
            normalized_in: Dict containing at least one of "pos", "vel", "acc".
                They are checked in that order and the first one present is used.
            position_sequence: Past positions; the last one (and, for "acc",
                the last two) along axis 1 are read.

        Returns:
            `normalized_in["pos"]` as is when "pos" is given, otherwise the most
            recent position shifted by the new velocity.
        """
        assert any([key in normalized_in for key in ["pos", "vel", "acc"]])

        if "pos" in normalized_in:
            # Zeroth euler step
            return normalized_in["pos"]
        else:
            most_recent_position = position_sequence[:, -1]
            if "vel" in normalized_in:
                # invert normalization
                velocity_stats = normalization_stats["velocity"]
                new_velocity = velocity_stats["mean"] + (
                    normalized_in["vel"] * velocity_stats["std"]
                )
            elif "acc" in normalized_in:
                # invert normalization.
                acceleration_stats = normalization_stats["acceleration"]
                acceleration = acceleration_stats["mean"] + (
                    normalized_in["acc"] * acceleration_stats["std"]
                )
                # Second Euler step
                most_recent_velocity = displacement_fn_set(
                    most_recent_position, position_sequence[:, -2]
                )
                new_velocity = most_recent_velocity + acceleration  # * dt = 1

            # First Euler step
            return shift_fn(most_recent_position, new_velocity)

    # Positional order must match the field order of CaseSetupFn.
    return CaseSetupFn(
        allocate_fn,
        preprocess_fn,
        allocate_eval_fn,
        preprocess_eval_fn,
        integrate_fn,
        displacement_fn,
        normalization_stats,
    )

In [86]:
# Domain extents as [min, max] rows, one row per axis: x in [0, 1], y in [0, 0.38].
# Hard-coded here; the alternative is np.array(ht2d_train.metadata["bounds"]).
bounds = np.array([[0.0, 1.0], [0.0, 0.38]])
# Side length of the simulation box along each axis (max - min).
box = np.diff(bounds, axis=1).ravel()

box

array([1.  , 0.38])

In [87]:
# Stand-alone demo of a periodic space built from `box`:
#   displacement_fn(Ra, Rb) -> displacement between two points,
#   shift_fn(R, dR)         -> R moved by dR.
# NOTE: this is built by hand as periodic; case_builder instead chooses between
# periodic and free space from the dataset metadata.
displacement_fn, shift_fn = space.periodic(side=jnp.array(box))
displacement_fn


<function jax_md.space.periodic.<locals>.displacement_fn(Ra: jax.Array, Rb: jax.Array, perturbation: Optional[jax.Array] = None, **unused_kwargs) -> jax.Array>

In [41]:
# Batched variant: maps displacement_fn over the leading axis of both arguments.
# NOTE(review): this cell's execution count (41) is out of order with its
# neighbours (87, 88); it was run against an earlier displacement_fn.
displacement_fn_set = vmap(displacement_fn, in_axes=(0, 0))

In [88]:
# Two points 0.2 apart along each axis. With box = (1.0, 0.38) the y separation
# exceeds half the box height (0.19), so the periodic displacement wraps to
# 0.2 - 0.38 = -0.18 while the x component stays 0.2.
Ra = jnp.array([0.3, 0.3])
Rb = jnp.array([0.1, 0.1])
displacement = displacement_fn(Ra, Rb)

# (shift_fn takes a position and a displacement: shift_fn(R, dR).)
print("Displacement:", displacement)

Displacement: [ 0.2  -0.18]


In [10]:
ht2d_case = case_builder(
    box=box,  # world size along each axis; (1.0, 0.38) as computed above
    metadata=ht2d_train.metadata,  # metadata dictionary
    input_seq_length=6,  # number of consecutive time steps fed to the model
    isotropic_norm=False,  # False: do not normalize all dimensions equally
    noise_std=noise_std,  # noise standard deviation used by the random-walk noise
)

In [11]:
# Inspect the assembled case: its functions and the normalization statistics.
ht2d_case

CaseSetupFn(allocate=<function case_builder.<locals>.allocate_fn at 0x7f2b1b0151b0>, preprocess=<PjitFunction of <function case_builder.<locals>.preprocess_fn at 0x7f2b1b015240>>, allocate_eval=<function case_builder.<locals>.allocate_eval_fn at 0x7f2b1b015630>, preprocess_eval=<PjitFunction of <function case_builder.<locals>.preprocess_eval_fn at 0x7f2b1b0156c0>>, integrate=<PjitFunction of <function case_builder.<locals>.integrate_fn at 0x7f2b1b015ab0>>, displacement=<function periodic.<locals>.displacement_fn at 0x7f2b1b1b20e0>, normalization_stats={'acceleration': {'mean': Array([4.47690509e-05, 1.96038036e-05], dtype=float64), 'std': Array([0.00252051, 0.00134007], dtype=float64)}, 'velocity': {'mean': Array([ 3.21281608e-03, -2.85042406e-05], dtype=float64), 'std': Array([0.00373654, 0.00117393], dtype=float64)}})

In [12]:
trainer = lagrangebench.Trainer(
    model=gns,  # the Haiku-transformed model, not the raw function
    case=ht2d_case,
    data_train=ht2d_train,
    data_valid=ht2d_valid,
    pushforward=pf_config,
    noise_std=noise_std,  # same value that was passed to case_builder
    metrics=["mse"],
    n_rollout_steps=20,  # equals extra_seq_length=20 of the validation set
    eval_n_trajs=1,
    lr_start=5e-4,
    log_steps=10,  # the log below shows a train/loss line every 10 steps
    eval_steps=50,  # the log below shows an evaluation after steps 50 and 100
    batch_size_infer=1,
)

# Short demonstration run; the third return value is discarded.
params, state, _ = trainer(step_max=100)



000, train/loss: 1.64200.
010, train/loss: 0.49071.
020, train/loss: 0.22857.
030, train/loss: 0.13987.
040, train/loss: 0.16064.
050, train/loss: 0.45200.
(eval) Reallocate neighbors list at step 1




(eval) From (2, 8098) to (2, 8258)
(eval) Reallocate neighbors list at step 6
(eval) From (2, 8258) to (2, 8623)
(eval) Reallocate neighbors list at step 13
(eval) From (2, 8623) to (2, 9131)
{'val/loss': 0.0008332215030598735, 'val/mse1': 1.538714813370024e-05, 'val/mse10': 0.0003000804581765784, 'val/mse5': 0.00014142839415986165, 'val/stdloss': 0.0, 'val/stdmse1': 0.0, 'val/stdmse10': 0.0, 'val/stdmse5': 0.0}
060, train/loss: 0.20161.
070, train/loss: 0.45718.
080, train/loss: 0.11952.
090, train/loss: 1.35452.
100, train/loss: 0.07526.
(eval) Reallocate neighbors list at step 1
(eval) From (2, 8098) to (2, 8233)
(eval) Reallocate neighbors list at step 6
(eval) From (2, 8233) to (2, 8678)
(eval) Reallocate neighbors list at step 10




(eval) From (2, 8678) to (2, 8936)
(eval) Reallocate neighbors list at step 16
(eval) From (2, 8936) to (2, 9191)
{'val/loss': 0.0005722186507974004, 'val/mse1': 1.4546645693945957e-05, 'val/mse10': 0.00027800085150634073, 'val/mse5': 0.00013646225000015484, 'val/stdloss': 0.0, 'val/stdmse1': 0.0, 'val/stdmse10': 0.0, 'val/stdmse5': 0.0}
