In [1]:
import jax
import jax.numpy as jnp
from fbpinns.problems import Problem

In [None]:
class HeatEquation1D(Problem):
    """
    Solves the 1D heat equation:
    
        u_t = α u_xx

    on the domain x ∈ [0, 1] and t ∈ [0, T],
    with homogeneous Dirichlet boundary conditions:
        u(0,t) = u(1,t) = 0
    and initial condition:
        u(x,0) = sin(πx)

    The analytical solution is:
        u(x,t) = sin(πx) * exp(-α π² t)
    """

    @staticmethod
    def init_params(alpha=1.0, N=40000):
        static_params = {
            "dims": (1, 2), 
            "alpha": alpha,
        }
        trainable_params = {
            'attention': jnp.zeros((N, 1))
            }
        return static_params, trainable_params

    @staticmethod
    def sample_constraints(all_params, domain, key, sampler, batch_shapes):
        
        x_batch_phys = domain.sample_interior(all_params, key, sampler, batch_shapes[0])
        
        required_ujs_phys = (
            (0, (0, 0)),  # u_xx
            (0, (1,)),    # u_t
        )
        
        return [[x_batch_phys, required_ujs_phys]]

    @staticmethod
    def constraining_fn(all_params, x_batch, u):
        """
        Enforces the Dirichlet BCs and the initial condition.
        Assumes x_batch has two columns: x and t.
        Constructs the solution as:
            u(x,t) = x*(1-x)*t * u + sin(πx)
        so that:
            u(x,0) = sin(πx),
            u(0,t) = 0, and
            u(1,t) = 0.
        """
        x = x_batch[:, 0:1]
        t = x_batch[:, 1:2]
        tanh = jax.nn.tanh
        sd = 0.1
        return tanh((-x)/sd)*tanh((1-x)/sd)*tanh((t)/sd) * u + jnp.sin(jnp.pi * x)

    @staticmethod
    def loss_fn(all_params, constraints):
        _, uxx, ut = constraints[0]
        alpha = all_params["static"]["problem"]["alpha"]
        
        residual = ut - alpha * uxx
        
        selected = all_params["trainable"]["problem"]["selected"].astype(jnp.int32)
        attention = all_params["trainable"]["problem"]["attention"][selected]  # (N,1)
        current_i = all_params["trainable"]["problem"]["current_i"]

        m1 = jnp.mean(residual**2)
        m2 = jnp.mean(((jnp.exp(-current_i)+attention)*residual)**2)
        jax.debug.print("curr_i = {i}, raw MSE = {m1:.6f}, weighted MSE = {m2:.6f}", i=current_i, m1=m1, m2=m2)
        jax.debug.print("residual max = {a}, attention head = {b}", a=jnp.max(jnp.abs(residual)), b=attention[:5, 0])
        
        return jnp.mean(
            # ((jnp.exp(-current_i)+attention)*residual) ** 2
            residual ** 2
            ), residual

    @staticmethod
    def exact_solution(all_params, x_batch, batch_shape=None):
        alpha = all_params["static"]["problem"]["alpha"]
        x = x_batch[:, 0:1]
        t = x_batch[:, 1:2]
        return jnp.sin(jnp.pi * x) * jnp.exp(-alpha * (jnp.pi ** 2) * t)
    

In [3]:
import numpy as np
from fbpinns.domains import RectangularDomainND
from fbpinns.decompositions import RectangularDecompositionND
from fbpinns.networks import FCN, ChebyshevKAN
from fbpinns.schedulers import LineSchedulerRectangularND
from fbpinns.constants import Constants, get_subdomain_ws
from fbpinns.trainers import FBPINNTrainer

# Set the final time for the simulation
T = 1.0
subdomain_xs=[np.linspace(0,1,5), np.linspace(0,1,5)]

# Create a Constants object to hold all hyperparameters
c = Constants(
    # Define the problem domain (x and t)
    domain=RectangularDomainND,
    domain_init_kwargs=dict(
        xmin=np.array([0.0, 0.0]),  # x in [0, 1] and t in [0, T]
        xmax=np.array([1.0, T])
    ),
    # Set the problem to our heat equation
    problem=HeatEquation1D,
    problem_init_kwargs=dict(
        alpha=1.0,
        N=100*100
        ),
    # Use a rectangular domain decomposition
    decomposition=RectangularDecompositionND,
    decomposition_init_kwargs=dict(
        # Split both x and t into 11 subdomains (you can adjust this as needed)
        subdomain_xs=subdomain_xs,
        subdomain_ws=get_subdomain_ws(subdomain_xs, 2.9),
        unnorm=(0, 1)
    ),
    # Use a fully-connected network; note the input layer size is 2 (x and t)
    # network=FCN,
    # network_init_kwargs=dict(
    #     layer_sizes=(2, 32, 1)
    # ),
    network=ChebyshevKAN,# place a fully-connected network in each subdomain
    network_init_kwargs=dict(
        input_dim=2,
        output_dim=1,
        degree=9
    ),
    scheduler = LineSchedulerRectangularND,
    scheduler_kwargs = dict(
        point=[0.], iaxis=0,
    ),
    # Number of training collocation points and testing points
    # ns=((200,),),
    # n_test=(500,),
    ns=((100,100),),
    n_test=(100,100),
    # Set training steps and optimizer parameters
    n_steps=10000,
    # optimiser_kwargs=dict(learning_rate=1e-3),
    clear_output=False,
    attention_tracking=True,
)

# Create the trainer and start training
run = FBPINNTrainer(c)
all_params = run.train()

KeyboardInterrupt: 

In [None]:
all_params

{'static': {'domain': {'xd': 2,
   'xmin': Array([0., 0.], dtype=float32),
   'xmax': Array([1., 1.], dtype=float32)},
  'problem': {'dims': (1, 2), 'alpha': 1.0},
  'decomposition': {'m': 25,
   'xd': 2,
   'subdomain': {'params': [Array([[-0.3625, -0.3625],
            [-0.3625, -0.1125],
            [-0.3625,  0.1375],
            [-0.3625,  0.3875],
            [-0.3625,  0.6375],
            [-0.1125, -0.3625],
            [-0.1125, -0.1125],
            [-0.1125,  0.1375],
            [-0.1125,  0.3875],
            [-0.1125,  0.6375],
            [ 0.1375, -0.3625],
            [ 0.1375, -0.1125],
            [ 0.1375,  0.1375],
            [ 0.1375,  0.3875],
            [ 0.1375,  0.6375],
            [ 0.3875, -0.3625],
            [ 0.3875, -0.1125],
            [ 0.3875,  0.1375],
            [ 0.3875,  0.3875],
            [ 0.3875,  0.6375],
            [ 0.6375, -0.3625],
            [ 0.6375, -0.1125],
            [ 0.6375,  0.1375],
            [ 0.6375,  0.3875],
    