In [30]:
import sys
import os

module_path = os.path.abspath(os.path.join('..'))

if module_path not in sys.path:
    sys.path.append(module_path)

In [31]:
import numpy as np
from fbpinns.domains import RectangularDomainND

domain = RectangularDomainND
domain_init_kwargs = dict(
    xmin = np.array([0.,]),
    xmax = np.array([24.,])
)

domain_init_kwargs

{'xmin': array([0.]), 'xmax': array([24.])}

In [32]:
import jax.numpy as jnp
from fbpinns.problems import Problem

from scipy.integrate import odeint
from scipy.interpolate import interp1d

class CompetitionModel(Problem):

    @staticmethod
    def init_params(params, u0, v0):
        
        r, a1, a2, b1, b2 = params 
        static_params = {
            "dims":(2,1),   # dims of solution and problem
            "r":r,
            "a1":a1,
            "a2":a2,
            "b1":b1,
            "b2":b2,
            "u0":u0,
            "v0":v0,
        }
        trainable_params = {
        }
        return static_params, trainable_params
    
    @staticmethod
    def sample_constraints(all_params, domain, key, sampler, batch_shapes):
        # Physics Loss
        x_batch_phys = domain.sample_interior(all_params, key, sampler, batch_shapes[0])
        required_ujs_phys = (
            (0, ()),  
            (1, ()), 
            (0, (0,)),  
            (1, (0,)), 
        )


        # Boundary Loss
        x_batch_boundary = jnp.array([0.]).reshape((1,1))
        u0 = all_params["static"]["problem"]["u0"]
        v0 = all_params["static"]["problem"]["v0"]
        boundary = jnp.array([u0, v0]).reshape((2,1))
        required_ujs_boundary = (
            (0, ()), 
            (1, ()),  
        )


        # # Data Loss
        # x_batch_data = jnp.linspace(0,10,15).astype(float).reshape((15,1))
        # r_true, a1_true, a2_true, b1_true, b2_true = [all_params['static']["problem"][key] for key in ('r', 'a1', 'a2', 'b1', 'b2')]
        # params = (r_true, a1_true, a2_true, b1_true, b2_true)
        # solution = odeint(CompetitionModel.model, [u0,v0], x_batch_data.reshape((15,)), args=(params,))
        # u_data = solution[:,0]
        # v_data = solution[:,1]
        # required_ujs_data = (
        #     (0, ()),  # Value of u at data points
        #     (1, ()),  # Value of v at data points
        # )


        return [[x_batch_phys, required_ujs_phys],
                [x_batch_boundary, boundary, required_ujs_boundary]]#,
                # [x_batch_data, u_data, v_data, required_ujs_data]]
    
    @staticmethod
    def loss_fn(all_params, constraints):
        
        r, a1, a2, b1, b2 = [all_params['static']["problem"][key] for key in ('r', 'a1', 'a2', 'b1', 'b2')]

        # Physics loss
        _, u, v, ut, vt = constraints[0]
        phys1 = jnp.mean((ut - u + a1*u**2 + a2*u*v)**2)
        phys2 = jnp.mean((vt - r*v + r*b1*u*v + r*b2*v**2)**2)
        phys = phys1 + phys2

        # Boundary loss
        _, uvc, u, v = constraints[1]
        print(uvc.shape)
        uc, vc = uvc
        boundary = 1e6*jnp.mean((u-uc)**2) + 1e6*jnp.mean((v-vc)**2)

        # # Data Loss
        # _, ud, vd, u, v = constraints[2]
        # print(u.shape, v.shape, ud.shape, vd.shape)
        # u = u.reshape(-1)  # Reshape u to be (15,)
        # v = v.reshape(-1)  # Ensure v is correctly populated with data first, then reshape
        # print(u.shape, v.shape, ud.shape, vd.shape)
        # data = 1e6*jnp.mean((u-ud)**2) + 1e6*jnp.mean((v-vd)**2)


        return phys + boundary #+ data

    @staticmethod
    def model(y, t, params):
        """
        Compute the derivatives of the system at time t.
        
        :param y: Current state of the system [u, v].
        :param t: Current time.
        :param params: Parameters of the model (a1, a2, b1, b2, r).
        :return: Derivatives [du/dt, dv/dt].
        """
        u, v = y  
        a1, a2, b1, b2, r = params  
        
        du_dt = u * (1 - a1 * u - a2 * v)
        dv_dt = r * v * (1 - b1 * u - b2 * v)
        
        return [du_dt, dv_dt]
    
    @staticmethod
    def exact_solution(all_params, x_batch, batch_shape=None):
        r = all_params['static']["problem"]['r']
        a1 = all_params['static']["problem"]['a1']
        a2 = all_params['static']["problem"]['a2']
        b1 = all_params['static']["problem"]['b1']
        b2 = all_params['static']["problem"]['b2']
        u0 = all_params["static"]["problem"]["u0"]
        v0 = all_params["static"]["problem"]["v0"]
        params = [r, a1, a2, b1, b2]
       
        t = jnp.arange(0, 25.02, 0.02)  
        
        # Solve the system 
        solution = odeint(CompetitionModel.model, [u0, v0], t, args=(params,))
        
        # Interpolation 
        u_interp = interp1d(t, solution[:, 0], kind='cubic')
        v_interp = interp1d(t, solution[:, 1], kind='cubic')
        
        u_data = u_interp(x_batch.flatten())
        v_data = v_interp(x_batch.flatten())
        
        # Combine 
        combined_solution = jnp.vstack((u_data, v_data)).T
        if batch_shape:
            combined_solution = combined_solution.reshape(batch_shape + (2,))
        
        return combined_solution

    
problem = CompetitionModel
params = [0.5, 0.3, 0.6, 0.7, 0.3]
problem_init_kwargs = dict(
    params=params, u0=2, v0=1,
)


In [33]:
from fbpinns.decompositions import RectangularDecompositionND

decomposition = RectangularDecompositionND # use a rectangular domain decomposition

decomposition_init_kwargs = dict(
    subdomain_xs = [np.linspace(0,24,15)],
    subdomain_ws = [(24/12)*np.ones((15,))],
    unnorm = (0., 1.),
)
decomposition_init_kwargs

{'subdomain_xs': [array([ 0.        ,  1.71428571,  3.42857143,  5.14285714,  6.85714286,
          8.57142857, 10.28571429, 12.        , 13.71428571, 15.42857143,
         17.14285714, 18.85714286, 20.57142857, 22.28571429, 24.        ])],
 'subdomain_ws': [array([2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])],
 'unnorm': (0.0, 1.0)}

In [34]:
from fbpinns.networks import FCN

network = FCN# place a fully-connected network in each subdomain
network_init_kwargs=dict(
    layer_sizes=[1,32,2],# with 2 hidden layers
)

In [35]:
from fbpinns.constants import Constants

c = Constants(
    domain=domain,
    domain_init_kwargs=domain_init_kwargs,
    problem=problem,
    problem_init_kwargs=problem_init_kwargs,
    decomposition=decomposition,
    decomposition_init_kwargs=decomposition_init_kwargs,
    network=network,
    network_init_kwargs=network_init_kwargs,
    ns=((2000,),),# use 200 collocation points for training
    n_test=(500,),# use 500 points for testing
    n_steps=20000,# number of training steps
    clear_output=True,
)

print(c)

<fbpinns.constants.Constants object at 0x000002755892D3A0>
run: test
domain: <class 'fbpinns.domains.RectangularDomainND'>
domain_init_kwargs: {'xmin': array([0.]), 'xmax': array([24.])}
problem: <class '__main__.CompetitionModel'>
problem_init_kwargs: {'params': [0.5, 0.3, 0.6, 0.7, 0.3], 'u0': 2, 'v0': 1}
decomposition: <class 'fbpinns.decompositions.RectangularDecompositionND'>
decomposition_init_kwargs: {'subdomain_xs': [array([ 0.        ,  1.71428571,  3.42857143,  5.14285714,  6.85714286,
        8.57142857, 10.28571429, 12.        , 13.71428571, 15.42857143,
       17.14285714, 18.85714286, 20.57142857, 22.28571429, 24.        ])], 'subdomain_ws': [array([2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])], 'unnorm': (0.0, 1.0)}
network: <class 'fbpinns.networks.FCN'>
network_init_kwargs: {'layer_sizes': [1, 32, 2]}
n_steps: 20000
scheduler: <class 'fbpinns.schedulers.AllActiveSchedulerND'>
scheduler_kwargs: {}
ns: ((2000,),)
n_test: (500,)
sampler: grid
optimiser: <f

In [36]:
from fbpinns.trainers import FBPINNTrainer

run = FBPINNTrainer(c)
all_params = run.train()

[WinError 32] The process cannot access the file because it is being used by another process: 'results/summaries/test/events.out.tfevents.1707909659.TS'
[WinError 32] The process cannot access the file because it is being used by another process: 'results/summaries/test/events.out.tfevents.1707909900.TS'
[WinError 32] The process cannot access the file because it is being used by another process: 'results/summaries/test/events.out.tfevents.1707909921.TS'
[WinError 32] The process cannot access the file because it is being used by another process: 'results/summaries/test/events.out.tfevents.1707910046.TS'
[WinError 32] The process cannot access the file because it is being used by another process: 'results/summaries/test/events.out.tfevents.1707910111.TS'
[WinError 32] The process cannot access the file because it is being used by another process: 'results/summaries/test/events.out.tfevents.1707910287.TS'
[WinError 32] The process cannot access the file because it is being used by anoth

AssertionError: 