* If you're running this on Google Colab, please uncomment and run the cell below.

In [1]:
 !pip install optax
 !pip install flax



In [1]:
%matplotlib inline
import os
import time
import jax
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt
from tqdm import trange
from jax import jvp, value_and_grad, vmap, lax
from flax import linen as nn
from typing import Sequence
from functools import partial
import numpy as np

## 1. SPINN

In [11]:
# forward function
class SPINN(nn.Module):
    """
    def input_encoding(self, t, x, y, L_x=1, L_y=1, M_x=1, M_y=1, M_t=1):
        w_x = 2.0 * jnp.pi / L_x
        w_y = 2.0 * jnp.pi / L_y
        k_x = jnp.arange(1, M_x + 1)
        k_y = jnp.arange(1, M_y + 1)
        k_xx, k_yy = jnp.meshgrid(k_x, k_y)
        k_xx = k_xx.flatten()
        k_yy = k_yy.flatten()
        k_t = jnp.power(10.0, jnp.arange(0, M_t + 1))
        out = jnp.hstack([1,  # Ensure scalar is wrapped as an array
            k_t * t,
            jnp.cos(k_x * w_x * x), jnp.cos(k_y * w_y * y),
            jnp.sin(k_x * w_x * x), jnp.sin(k_y * w_y * y),
            jnp.cos(k_xx * w_x * x) * jnp.cos(k_yy * w_y * y),
            jnp.cos(k_xx * w_x * x) * jnp.sin(k_yy * w_y * y),
            jnp.sin(k_xx * w_x * x) * jnp.cos(k_yy * w_y * y),
            jnp.sin(k_xx * w_x * x) * jnp.sin(k_yy * w_y * y)
        ])
        return out

    features: Sequence[int]

    @nn.compact
    def __call__(self, x, y, z):
        outputs = []
        input_list = []
        for i in range(len(x)):
            input = self.input_encoding(x[i], y[i], z[i])  # Encode input
            input_list.append(input)  # Add to the list

        # Stack all inputs along a new axis (e.g., axis=0 for batch dimension)
        inputs = jnp.stack(input_list, axis=0)
        print("inputs encoding shape: ", inputs.shape)

        init = nn.initializers.glorot_normal()
        for j in range(inputs.shape[1]):
            X = inputs[:, j].reshape(-1, 1)
            print("X shape before", X.shape)
            for fs in self.features[:-1]:
                X = nn.Dense(fs, kernel_init=init)(X)
                X = jax.nn.tanh(X)
            X = nn.Dense(self.features[-1], kernel_init=init)(X)
            print("X shape after", X.shape)
            outputs+=[jnp.transpose(X, (1, 0))]

        # Stack the outputs into a tensor

        result = outputs[0]
        for output in outputs[1:]:
            result = jnp.einsum('...a,...b->...ab', result, output)

        print("result shape", result.shape)
        return result

    """
    features: Sequence[int]

    @nn.compact
    def __call__(self, x, y, z):
        inputs, outputs = [x, y, z], []
        #print("x shape", x.shape)
        init = nn.initializers.glorot_normal()
        for X in inputs:
            #print("X shape before", X.shape)
            for fs in self.features[:-1]:
                X = nn.Dense(fs, kernel_init=init)(X)
                X = nn.activation.tanh(X)
            X = nn.Dense(self.features[-1], kernel_init=init)(X)

            #print("X shape after", X.shape)
            outputs += [jnp.transpose(X, (1, 0))]
        xy = jnp.einsum('fx, fy->fxy', outputs[0], outputs[1])
        result = jnp.einsum('fxy, fz->xyz', xy, outputs[-1])
        #print("result shape", result.shape)
        return result


# hessian-vector product
def hvp_fwdfwd(f, primals, tangents, return_primals=False):
    g = lambda primals: jvp(f, (primals,), tangents)[1]
    primals_out, tangents_out = jvp(g, primals, tangents)
    if return_primals:
        return primals_out, tangents_out
    else:
        return tangents_out


# loss function
def spinn_loss_klein_gordon3d(apply_fn, *train_data, epsilon, lam):
    """
    def residual_loss(params, t, x, y, source_term):
        # calculate u
        u = apply_fn(params, t, x, y)
        # tangent vector dx/dx
        # assumes t, x, y have same shape (very important)
        v = jnp.ones(t.shape)
        # 2nd derivatives of u
        utt = hvp_fwdfwd(lambda t: apply_fn(params, t, x, y), (t,), (v,))
        uxx = hvp_fwdfwd(lambda x: apply_fn(params, t, x, y), (x,), (v,))
        uyy = hvp_fwdfwd(lambda y: apply_fn(params, t, x, y), (y,), (v,))
        return jnp.mean((utt - uxx - uyy + u**2 - source_term)**2)

    def residuals_and_weights(params, t, x, y, source_term, tol=epsilon):
        # Vectorized computation of u and its derivatives
        u = vmap(lambda t_i: apply_fn(params, t_i, x, y))(t)
        v = jnp.ones_like(t)

        # Compute all derivatives using vmap
        utt = vmap(lambda t_i: hvp_fwdfwd(
            lambda t: apply_fn(params, t, x, y), (t_i,), (jnp.ones_like(t_i),)
        ))(t)

        uxx = vmap(lambda t_i: hvp_fwdfwd(
            lambda x: apply_fn(params, t_i, x, y), (x,), (jnp.ones_like(x),)
        ))(t)

        uyy = vmap(lambda t_i: hvp_fwdfwd(
            lambda y: apply_fn(params, t_i, x, y), (y,), (jnp.ones_like(y),)
        ))(t)

        # Compute residuals for all time steps at once
        residuals = utt - uxx - uyy + u**2 - source_term

        # Compute loss at each time step
        L_t = jnp.mean(residuals**2, axis=1)

        # Create matrix M for cumulative sum (lower triangular)
        n_t = L_t.shape[0]
        M = jnp.tril(jnp.ones((n_t, n_t)))

        # Compute weights using matrix multiplication
        W = lax.stop_gradient(jnp.exp(-tol * (M @ L_t)))
        print("L_t = ", L_t)
        print("W = ", W)
        return L_t, W

    """
    def causal_residual_loss(params, t, x, y, source_term, epsilon):
    # Calculate weights
        # Calculate causal weighted residual loss
        n_t = t.shape[0]
        residuals = []
        weights = []
        cumulative_loss = 0

        # Compute residual at time step i
        utt= hvp_fwdfwd(lambda t: apply_fn(params, t, x, y), (t,), (jnp.ones_like(t),))
        uxx = hvp_fwdfwd(lambda x: apply_fn(params, t, x, y), (x,), (jnp.ones_like(x),))
        uyy = hvp_fwdfwd(lambda y: apply_fn(params, t, x, y), (y,), (jnp.ones_like(y),))
        u = apply_fn(params, t, x, y)
        r_pred = utt - uxx - uyy + u**2 - source_term

        L_t = jnp.mean(r_pred**2, axis = 1)

        M = np.triu(np.ones((n_t, n_t)), k=1).T
        W = lax.stop_gradient(jnp.exp(-epsilon*(M @L_t)))
        # Compute loss for the current time step
        residual_loss = jnp.mean(W * L_t)

        return residual_loss

    def initial_loss(params, t, x, y, u):
        return jnp.mean((apply_fn(params, t, x, y) - u)**2)

    def boundary_loss(params, t, x, y, u):
        loss = 0.
        for i in range(4):
            loss += (1/4.) * jnp.mean((apply_fn(params, t[i], x[i], y[i]) - u[i])**2)
        return loss

    # unpack data
    tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub = train_data

    def total_loss(params):
        # Compute residuals and weights
        loss_res = causal_residual_loss(params, tc, xc, yc, uc, epsilon)
        #print("loss_res shape", loss_res.shape)
        # Initial and boundary losses
        loss_init = initial_loss(params, ti, xi, yi, ui)
        #print("loss_init shape", loss_init.shape)

        loss_bound = boundary_loss(params, tb, xb, yb, ub)
        #print("loss_bound shape", loss_bound.shape)
        # Total loss
        return loss_res + lam*loss_init + loss_bound

    return total_loss


# optimizer step function
@partial(jax.jit, static_argnums=(0,))
def update_model(optim, gradient, params, state):
    updates, state = optim.update(gradient, state)
    params = optax.apply_updates(params, updates)
    return params, state

## 2. Data generator

In [3]:
# 2d time-dependent klein-gordon exact u
def _klein_gordon3d_exact_u(t, x, y):
    return (x + y) * jnp.cos(2*t) + (x * y) * jnp.sin(2*t)


# 2d time-dependent klein-gordon source term
def _klein_gordon3d_source_term(t, x, y):
    u = _klein_gordon3d_exact_u(t, x, y)
    return u**2 - 4*u


# train data
def spinn_train_generator_klein_gordon3d(nc, key):
    keys = jax.random.split(key, 3)
    # collocation points
    tc = jax.random.uniform(keys[0], (nc, 1), minval=0., maxval=10.)
    xc = jax.random.uniform(keys[1], (nc, 1), minval=-1., maxval=1.)
    yc = jax.random.uniform(keys[2], (nc, 1), minval=-1., maxval=1.)
    tc_mesh, xc_mesh, yc_mesh = jnp.meshgrid(tc.ravel(), xc.ravel(), yc.ravel(), indexing='ij')
    uc = _klein_gordon3d_source_term(tc_mesh, xc_mesh, yc_mesh)
    # initial points
    ti = jnp.zeros((1, 1))
    xi = xc
    yi = yc
    ti_mesh, xi_mesh, yi_mesh = jnp.meshgrid(ti.ravel(), xi.ravel(), yi.ravel(), indexing='ij')
    ui = _klein_gordon3d_exact_u(ti_mesh, xi_mesh, yi_mesh)
    # boundary points (hard-coded)
    tb = [tc, tc, tc, tc]
    xb = [jnp.array([[-1.]]), jnp.array([[1.]]), xc, xc]
    yb = [yc, yc, jnp.array([[-1.]]), jnp.array([[1.]])]
    ub = []
    for i in range(4):
        tb_mesh, xb_mesh, yb_mesh = jnp.meshgrid(tb[i].ravel(), xb[i].ravel(), yb[i].ravel(), indexing='ij')
        ub += [_klein_gordon3d_exact_u(tb_mesh, xb_mesh, yb_mesh)]
    return tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub


# test data
def spinn_test_generator_klein_gordon3d(nc_test):
    t = jnp.linspace(0, 10, nc_test)
    x = jnp.linspace(-1, 1, nc_test)
    y = jnp.linspace(-1, 1, nc_test)
    t = jax.lax.stop_gradient(t)
    x = jax.lax.stop_gradient(x)
    y = jax.lax.stop_gradient(y)
    tm, xm, ym = jnp.meshgrid(t, x, y, indexing='ij')
    u_gt = _klein_gordon3d_exact_u(tm, xm, ym)
    t = t.reshape(-1, 1)
    x = x.reshape(-1, 1)
    y = y.reshape(-1, 1)
    return t, x, y, u_gt, tm, xm, ym

## 3. Utils

In [5]:
def relative_l2(u, u_gt):
    return jnp.linalg.norm(u-u_gt) / jnp.linalg.norm(u_gt)

def plot_klein_gordon3d(t, x, y, u):
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(t, x, y, c=u, s=0.5, cmap='seismic')
    ax.set_title('U(t, x, y)', fontsize=20)
    ax.set_xlabel('t', fontsize=18, labelpad=10)
    ax.set_ylabel('x', fontsize=18, labelpad=10)
    ax.set_zlabel('y', fontsize=18, labelpad=10)
    plt.show()

In [6]:
import numpy as np

## 4. Main function

In [7]:
def main(NC, NI, NB, NC_TEST, SEED, LR, EPOCHS, N_LAYERS, FEATURES, LOG_ITER, EPSILON, LAM):
    # force jax to use one device
    os.environ["CUDA_VISIBLE_DEVICES"]="0"
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

    # random key
    key = jax.random.PRNGKey(SEED)
    key, subkey = jax.random.split(key, 2)

    # feature sizes
    feat_sizes = tuple(FEATURES for _ in range(N_LAYERS))

    # make & init model
    model = SPINN(feat_sizes)
    params = model.init(subkey, jnp.ones((NC, 1)), jnp.ones((NC, 1)), jnp.ones((NC, 1)))

    # optimizer
    optim = optax.adam(LR)
    state = optim.init(params)

    # dataset
    key, subkey = jax.random.split(key, 2)
    train_data = spinn_train_generator_klein_gordon3d(NC, subkey)
    t, x, y, u_gt, tm, xm, ym = spinn_test_generator_klein_gordon3d(NC_TEST)

    # forward & loss function
    apply_fn = jax.jit(model.apply)
    loss_fn = spinn_loss_klein_gordon3d(apply_fn, *train_data, epsilon = EPSILON, lam = LAM)

    @jax.jit
    def train_one_step(params, state):
        # compute loss and gradient
        loss, gradient = value_and_grad(loss_fn)(params)
        # update state
        params, state = update_model(optim, gradient, params, state)
        return loss, params, state

    # training
    loss_log = np.array([])
    error_log = np.array([])

    start = time.time()
    for e in trange(1, EPOCHS+1):
        # single run
        loss, params, state = train_one_step(params, state)
        if e % LOG_ITER == 0:
            u = apply_fn(params, t, x, y)

            loss_log = np.append(loss_log, loss)

            error = relative_l2(u, u_gt)

            error_log = np.append(error_log, error)

            print(f'Epoch: {e}/{EPOCHS} --> loss: {loss:.8f}, error: {error:.8f}')
    end = time.time()
    print(f'Runtime: {((end-start)/EPOCHS*1000):.2f} ms/iter.')

    print('Solution:')

    np.save(f"loss_log.npy", loss_log)
    np.save("error_log.npy", error_log)

    u = apply_fn(params, t, x, y)
    plot_klein_gordon3d(tm, xm, ym, u)




## 5. Run!

In [None]:
main(NC=64, NI=64, NB=64, NC_TEST=100, SEED=444, LR=1e-3, EPOCHS=50000, N_LAYERS=4, FEATURES=64, LOG_ITER=100, EPSILON = 1e-3, LAM = 1000)

  3%|▎         | 135/5000 [00:11<02:41, 30.11it/s]

Epoch: 100/5000 --> loss: 7.36400366, error: 1.00068188


  5%|▍         | 234/5000 [00:16<02:54, 27.36it/s]

Epoch: 200/5000 --> loss: 6.52692032, error: 0.98213446


  7%|▋         | 335/5000 [00:20<02:09, 35.95it/s]

Epoch: 300/5000 --> loss: 6.14153576, error: 0.97260010


  9%|▊         | 435/5000 [00:23<02:08, 35.62it/s]

Epoch: 400/5000 --> loss: 5.90241337, error: 0.96486682


 11%|█         | 535/5000 [00:27<02:06, 35.31it/s]

Epoch: 500/5000 --> loss: 5.74668550, error: 0.96326286


 13%|█▎        | 635/5000 [00:32<02:24, 30.11it/s]

Epoch: 600/5000 --> loss: 5.62800694, error: 0.97067118


 15%|█▍        | 735/5000 [00:35<01:59, 35.66it/s]

Epoch: 700/5000 --> loss: 5.39012432, error: 0.98729521


 17%|█▋        | 835/5000 [00:39<01:53, 36.82it/s]

Epoch: 800/5000 --> loss: 5.16672611, error: 0.96125859


 19%|█▊        | 934/5000 [00:43<02:32, 26.73it/s]

Epoch: 900/5000 --> loss: 5.07778120, error: 0.99210238


 21%|██        | 1035/5000 [00:47<01:50, 35.85it/s]

Epoch: 1000/5000 --> loss: 4.92305279, error: 1.06662786


 23%|██▎       | 1135/5000 [00:51<01:45, 36.77it/s]

Epoch: 1100/5000 --> loss: 4.71471024, error: 1.09404588


 25%|██▍       | 1235/5000 [00:54<01:46, 35.48it/s]

Epoch: 1200/5000 --> loss: 5.63543510, error: 1.06288469


 27%|██▋       | 1335/5000 [00:59<01:56, 31.47it/s]

Epoch: 1300/5000 --> loss: 4.14161015, error: 1.01593637


 29%|██▊       | 1435/5000 [01:03<01:38, 36.08it/s]

Epoch: 1400/5000 --> loss: 3.96455979, error: 0.93565112


 31%|███       | 1535/5000 [01:06<01:35, 36.13it/s]

Epoch: 1500/5000 --> loss: 3.81172442, error: 0.87791556


 33%|███▎      | 1634/5000 [01:11<02:09, 25.90it/s]

Epoch: 1600/5000 --> loss: 3.65430331, error: 0.86402243


 35%|███▍      | 1735/5000 [01:15<01:31, 35.70it/s]

Epoch: 1700/5000 --> loss: 3.53450322, error: 0.85979337


 37%|███▋      | 1835/5000 [01:18<01:30, 35.03it/s]

Epoch: 1800/5000 --> loss: 3.43544507, error: 0.91424441


 39%|███▊      | 1934/5000 [01:22<01:35, 32.14it/s]

Epoch: 1900/5000 --> loss: 3.29237628, error: 0.96662337


 41%|████      | 2035/5000 [01:27<01:25, 34.71it/s]

Epoch: 2000/5000 --> loss: 3.12290549, error: 1.12497818


 42%|████▏     | 2099/5000 [01:30<01:49, 26.44it/s]