In [None]:
import jax.numpy as jnp
import jax
import numpy as np
from fbpinns.problems import Problem

class HarmonicOscillator1D_MultiFreq(Problem):
    """Solves the time-dependent damped harmonic oscillator using hard boundary conditions

          d^2 u      du
        m ----- + mu -- + ku = f(t)
          dt^2       dt

        The forcing f(t) is manufactured from the exact solution
            u(t) = sin(w0*t) + sin(w1*t)
        i.e.
            f(t) = sum_{w in (w0, w1)} [ mu*w*cos(w*t) + (k - m*w^2)*sin(w*t) ]
        which for m=0, mu=1, k=0 reduces to w0*cos(w0*t) + w1*cos(w1*t)
        (e.g. 4cos(4t) + 40cos(40t) for w0=4, w1=40).

        Boundary conditions (implied by the exact solution, enforced exactly):
        u (0) = 0
        u'(0) = w0 + w1
    """

    @staticmethod
    def init_params(m=0, mu=1, k=0, w0=40, w1=40, sd=0.1):
        """Return (static_params, trainable_params) for the problem.

        m, mu, k : coefficients of u'', u' and u in the ODE
        w0, w1   : angular frequencies of the two exact-solution modes
        sd       : length scale of the tanh hard-constraint ansatz
        No trainable problem parameters are used, hence the empty dict.
        """

        static_params = {
            "dims":(1,1),
            "m":m,
            "mu":mu,
            "k":k,
            "sd":sd,
            "w0":w0,
            "w1":w1,
            }

        return static_params, {}

    @staticmethod
    def sample_constraints(all_params, domain, key, sampler, batch_shapes):
        """Sample interior collocation points and request u, du/dt and d^2u/dt^2 there."""

        x_batch_phys = domain.sample_interior(all_params, key, sampler, batch_shapes[0])
        # (output index, derivative indices): u, u_t, u_tt -- unpacked in loss_fn
        required_ujs_phys = (
            (0,()),
            (0,(0,)),
            (0,(0,0))
        )
        return [[x_batch_phys, required_ujs_phys],]

    @staticmethod
    def constraining_fn(all_params, x_batch, u):
        """Apply the hard constraints u(0) = 0 and u'(0) = w0 + w1 to the network output.

        With s = tanh(x/sd):
          * sd*s*(1 - s**2) has value 0 and slope 1 at x=0 and decays to 0
            away from x=0, so scaling it by (w0 + w1) injects the required
            initial slope without adding a far-field offset;
          * s**2 has value 0 and slope 0 at x=0, so the network correction
            cannot disturb either condition.
        """
        p = all_params["static"]["problem"]
        sd, w0, w1 = p["sd"], p["w0"], p["w1"]
        s = jnp.tanh(x_batch[:,0:1]/sd)

        # A pure s**2 * u ansatz would force u'(0) = 0, contradicting the exact
        # solution (u'(0) = w0 + w1) so the residual could never vanish near t=0.
        u = (w0 + w1)*sd*s*(1 - s**2) + s**2 * u
        return u

    @staticmethod
    def loss_fn(all_params, constraints):
        """Return (mean squared ODE residual, pointwise residual).

        The forcing is derived from the exact solution for arbitrary m, mu, k,
        so exact_solution stays consistent with the loss for every parameter
        choice (for m=0, mu=1, k=0 it equals w0*cos(w0*t) + w1*cos(w1*t)).
        """
        p = all_params["static"]["problem"]
        m, mu, k = p["m"], p["mu"], p["k"]
        w0, w1 = p["w0"], p["w1"]
        x_batch, u, ut, utt = constraints[0]
        t = x_batch[:,0:1]

        # f = m*u'' + mu*u' + k*u evaluated on u = sin(w0 t) + sin(w1 t)
        forcing = sum(
            mu*w*jnp.cos(w*t) + (k - m*w**2)*jnp.sin(w*t)
            for w in (w0, w1)
        )
        phys = m*utt + mu*ut + k*u - forcing
        loss = jnp.mean(phys**2)
        return loss, phys

    @staticmethod
    def exact_solution(all_params, x_batch, batch_shape=None):
        """Return u(t) = sin(w0*t) + sin(w1*t) at the first input column (batch_shape is unused)."""
        x, sin = x_batch[:,0:1], jnp.sin
        w0, w1 = all_params["static"]["problem"]["w0"], all_params["static"]["problem"]["w1"]
        u = sin(w0*x) + sin(w1*x)
        return u

In [1]:
import numpy as np

from fbpinns.domains import RectangularDomainND
from fbpinns.decompositions import RectangularDecompositionND
from fbpinns.networks import FCN, ChebyshevKAN
from fbpinns.constants import Constants
from fbpinns.trainers import FBPINNTrainer, PINNTrainer


# NOTE: this cell uses HarmonicOscillator1D_MultiFreq, so the cell defining it
# must be executed first; otherwise Python raises NameError.

# Single source of truth for the domain and its decomposition, so subdomain
# centres/widths cannot drift out of sync with the domain bounds.
XMIN, XMAX = 0, 5
N_SUBDOMAINS = 30
SUBDOMAIN_WIDTH = 0.5

USE_FBPINN = False  # True -> FBPINNTrainer, False -> PINNTrainer
USE_KAN = False     # True -> ChebyshevKAN network, False -> FCN

if USE_KAN:
    network_cls, network_kwargs = ChebyshevKAN, dict(
        input_dim=1,
        output_dim=1,
        degree=20,
    )
else:
    network_cls, network_kwargs = FCN, dict(
        layer_sizes=[1,32,1],
    )

c = Constants(
    domain=RectangularDomainND,
    domain_init_kwargs=dict(
        xmin=np.array([XMIN,]),
        xmax=np.array([XMAX,]),
    ),
    problem=HarmonicOscillator1D_MultiFreq,
    problem_init_kwargs=dict(
        m=0,
        mu=1,
        k=0,
        w0=4,
        w1=40,
        sd=0.1
    ),
    decomposition=RectangularDecompositionND,
    decomposition_init_kwargs=dict(
        subdomain_xs=[np.linspace(XMIN, XMAX, N_SUBDOMAINS)],
        subdomain_ws=[SUBDOMAIN_WIDTH*np.ones((N_SUBDOMAINS,))],
        unnorm=(0.,3.),
    ),
    network=network_cls,
    network_init_kwargs=network_kwargs,
    optimiser_kwargs = dict(
        learning_rate=0.01
    ),
    ns=((1000,),),
    n_test=(1000,),
    n_steps=50000,
    clear_output=True,
    attention_tracking_kwargs=dict(
        eta_lr = 1e-2,
        gamma_decay = 0.99,
        out_dim=1,
        N=1000
    ),
)

trainer_cls = FBPINNTrainer if USE_FBPINN else PINNTrainer
run = trainer_cls(c)
all_params = run.train()

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x74b07574d670>>
Traceback (most recent call last):
  File "/vol/bitbucket/ss7921/dlenv/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


NameError: name 'HarmonicOscillator1D_MultiFreq' is not defined