# In [1]:
import jax.numpy as jnp
from fbpinns.problems import Problem

def _manufactured_solution(x_batch):
    """Return the manufactured solution u(x) = prod_i sin(pi * x_i).

    The product runs over axis 1 (one point per row), and the reduced axis is
    kept with size 1 so the result is a single-column array.
    """
    return jnp.prod(jnp.sin(jnp.pi * x_batch), axis=1, keepdims=True)


class Helmholtz100D(Problem):
    """
    100-D Helmholtz benchmark:

        -Δu + α u = f,
        u_exact(x) = ∏_i sin(π x_i),
        f(x) = (D π² + α) u_exact(x),   D = dims[1] (100 by default),

    since Δu_exact = -D π² u_exact. Homogeneous Dirichlet conditions on the
    faces x_i = 0 and x_i = 1 are imposed exactly by a smooth tanh multiplier.
    """

    @staticmethod
    def init_params(alpha: float = 10.0, sd: float = 0.1):
        """Build the problem parameters.

        Args:
            alpha: coefficient α of the zeroth-order term.
            sd: length scale of the tanh boundary multiplier (smaller = sharper).

        Returns:
            (static_params, trainable_params); there are no trainable params.
        """
        static_params = {
            "dims": (1, 100),  # (output dim, input dim): scalar u over a 100-D domain
            "alpha": alpha,
            "sd": sd,
        }
        return static_params, {}

    @staticmethod
    def sample_constraints(all_params, domain, key, sampler, batch_shapes):
        """Sample interior collocation points and request u and every ∂²u/∂x_i².

        Returns a single constraint ``[x, requests]`` where ``requests`` is
        ``(0, ())`` for u followed by ``(0, (i, i))`` for i = 0..D-1; the
        values arrive in ``loss_fn`` in this same order.
        """
        x = domain.sample_interior(all_params, key, sampler, batch_shapes[0])
        n_dims = all_params["static"]["dims"][1]
        requests = ((0, ()),) + tuple((0, (i, i)) for i in range(n_dims))
        return [[x, requests]]

    @staticmethod
    def constraining_fn(all_params, x_batch, net_out):
        """Turn the raw network output into a trial solution with zero boundary values.

        Returns u(x) = M(x) * N(x), where N is the first network output channel
        and M(x) = ∏_i tanh(x_i / sd) * tanh((1 - x_i) / sd) vanishes whenever
        any x_i equals 0 or 1.
        """
        sd = all_params["static"]["sd"]
        # jnp.tanh: the bare name ``jax`` is not imported in this notebook.
        boundary_mask = jnp.prod(
            jnp.tanh(x_batch / sd) * jnp.tanh((1 - x_batch) / sd),
            axis=1,
            keepdims=True,
        )
        raw = net_out[:, 0:1]
        # The boundary data are zero, so no lift term is needed: adding the exact
        # solution here would hand the answer to the model and make the error
        # against ``exact_solution`` meaningless.
        return boundary_mask * raw

    @staticmethod
    def loss_fn(all_params, constraints):
        """Mean-squared residual of -Δu + α u - f at the collocation points.

        ``constraints[0]`` unpacks as (x, u, ∂²u/∂x_0², ..., ∂²u/∂x_{D-1}²),
        in the order requested by ``sample_constraints``.

        Returns:
            (mean of r**2, pointwise residual r).
        """
        x, u, *second_derivs = constraints[0]
        alpha = all_params["static"]["alpha"]
        n_dims = all_params["static"]["dims"][1]
        laplacian = sum(second_derivs)  # Δu = Σ_i ∂²u/∂x_i²
        # Δ(∏ sin(π x_i)) = -D π² ∏ sin(π x_i), hence f = (D π² + α) u_exact.
        f = (n_dims * jnp.pi**2 + alpha) * _manufactured_solution(x)
        r = -laplacian + alpha * u - f
        return jnp.mean(r**2), r

    @staticmethod
    def exact_solution(all_params, x_batch, batch_shape=None):
        """Return u_exact(x) = ∏_i sin(π x_i) as a single-column array."""
        return _manufactured_solution(x_batch)


# In [None]:
import numpy as np

from fbpinns.domains import RectangularDomainND
from fbpinns.decompositions import RectangularDecompositionND
from fbpinns.networks import FCN, ChebyshevKAN, ChebyshevAdaptiveKAN, StackedChebyshevKAN, OptimizedStackedChebyshevKAN
from fbpinns.schedulers import LineSchedulerRectangularND
from fbpinns.constants import Constants, get_subdomain_ws
from fbpinns.trainers import FBPINNTrainer
from fbpinns.attention import RBAttention

# Input dimensionality; must agree with Helmholtz100D's static "dims" = (1, 100).
D = 100
# Random collocation points per training step. 125000 (= 50**3) is the value
# the attention tracker was configured with; both now read this one constant.
N_COLLOCATION = 125_000
N_TEST = 1_000

# A single subdomain per axis, centred in [0, 1] and spanning it exactly
# (inputs normalise to [-1, 1]). np.linspace(0, 1, 1) returns [0.], i.e. a
# corner rather than the centre, and get_subdomain_ws sizes widths from the
# spacing between neighbouring centres, which one centre does not have.
# NOTE(review): confirm the width choice against the decomposition's windowing.
subdomain_xs = [np.array([0.5])] * D
subdomain_ws = [np.array([1.0])] * D

c = Constants(
    domain=RectangularDomainND,
    domain_init_kwargs=dict(
        xmin=np.zeros(D),
        xmax=np.ones(D),
    ),
    problem=Helmholtz100D,
    problem_init_kwargs=dict(),
    decomposition=RectangularDecompositionND,
    decomposition_init_kwargs=dict(
        subdomain_xs=subdomain_xs,
        subdomain_ws=subdomain_ws,
        unnorm=(0., 3.),
    ),
    network=OptimizedStackedChebyshevKAN,
    # assumes dims = [input, hidden..., output] with len(degrees) == len(dims) - 1;
    # input must be D and output 1 to match the problem's dims (1, D).
    network_init_kwargs=dict(
        dims=[D, 4, 1],
        degrees=[4, 4],
    ),
    optimiser_kwargs=dict(
        learning_rate=0.001,
    ),
    # A tensor-product grid with 50 points per axis would hold 50**100 points.
    # Draw N_COLLOCATION random points instead; the batch shape keeps one entry
    # per axis and its product equals the number of points.
    sampler="uniform",
    ns=((N_COLLOCATION,) + (1,) * (D - 1),),
    # NOTE(review): a dense test grid is infeasible in 100-D. If the trainer
    # samples test points on a grid, this is a line along x_0 with every other
    # coordinate at 0 (where u_exact = 0) — confirm how test points are drawn.
    n_test=(N_TEST,) + (1,) * (D - 1),
    n_steps=10000,
    clear_output=False,
    attention_tracker=RBAttention,
    attention_tracking_kwargs=dict(
        eta_lr=1e-2,
        gamma_decay=0.99,
        out_dim=1,
        N=N_COLLOCATION,  # presumably must equal the collocation batch size — confirm
    ),
    summary_freq=100,  # outputs train stats to command line
)

run = FBPINNTrainer(c)
all_params = run.train()