In [1]:
import tensorflow as tf
import numpy as np
import nbimporter
from PDE_Classes import *

# Keras default float type -> float64, so the layers created below compute in double
# precision, consistent with the dtype='float64' tensors built in the solver classes.
tf.keras.backend.set_floatx('float64')


# Fully-connected PINN network: input scaling -> hidden Dense stack -> linear output layer.
class PINN_Architecture(tf.keras.Model):
    """Feed-forward network used as the trainable model of a PINN (sub)domain.

    Inputs are mapped linearly so that ``x0 -> 0`` and ``x1 -> 1``, then passed
    through ``num_hidden_layers`` Dense layers of equal width and a final linear
    Dense layer with ``output_dim`` units.

    Args:
        x0, x1: lower / upper bounds used for the input scaling.
        num_hidden_layers: number of hidden Dense layers.
        num_neurons_per_layer: width of every hidden layer.
        output_dim: number of output units.
        activation: hidden-layer activation (anything accepted by
            ``tf.keras.activations.get``).
        kernel_initializer: initializer for the hidden-layer kernels.
        **kwargs: forwarded unchanged to ``tf.keras.Model``.
    """

    def __init__(self, x0, x1,
                 num_hidden_layers=2,
                 num_neurons_per_layer=20,
                 output_dim=1,
                 activation=tf.keras.activations.swish,
                 kernel_initializer='glorot_normal',
                 **kwargs):
        super().__init__(**kwargs)

        # Hyperparameters and scaling bounds, kept as public attributes
        self.num_hidden_layers = num_hidden_layers
        self.output_dim = output_dim
        self.x0 = x0
        self.x1 = x1

        # Layers are created in the order scale -> hidden -> out so the order of
        # trainable_variables stays the same for callers that rely on it.
        self.scale = tf.keras.layers.Lambda(lambda z: (z - x0) / (x1 - x0))
        hidden_activation = tf.keras.activations.get(activation)
        self.hidden = [
            tf.keras.layers.Dense(num_neurons_per_layer,
                                  activation=hidden_activation,
                                  kernel_initializer=kernel_initializer)
            for _ in range(num_hidden_layers)
        ]
        self.out = tf.keras.layers.Dense(output_dim)

    def call(self, X):
        """Forward pass: scale ``X``, apply every hidden layer, return the linear output."""
        Z = self.scale(X)
        for layer in self.hidden:
            Z = layer(Z)
        return self.out(Z)



In [None]:
# Finite-difference (FD) model of one subdomain; callable like a PINN model.
class FD_1D_Steady():
    """Finite-difference model for a 1D steady advection-diffusion subdomain.

    Assembles the FD matrix for ``-nu*u'' + beta*u'`` on the grid ``X``. The
    diffusion term uses a central second difference. The advection term uses an
    upwind difference of order ``pde.order`` (1 or 2). The class then computes an
    initial nodal solution ``self.u`` with a unit right-hand side. Independent
    random draws are used as placeholder values for stencil neighbours that lie
    outside the subdomain.

    Args:
        X: 1D array of grid points; spacing is taken from ``X[1] - X[0]``, so a
            uniform grid is assumed.
        BC: (left, right) coordinates of the physical domain boundary.
        pde: object providing ``nu``, ``beta``, ``order`` and ``f(x)``; ``f`` is
            used as the nodal value at a physical boundary end of ``X``.

    Raises:
        ValueError: if ``pde.order`` is neither 1 nor 2.
    """

    def __init__(self, X, BC, pde):
        self.X = X
        n = len(X)
        x_left, x_right = X[0], X[-1]
        h = X[1] - X[0]

        nu, beta, order = pde.nu, pde.beta, pde.order

        # Diffusion stencil: coefficients of u[i+1] (a), u[i] (b), u[i-1] (c)
        a = -nu / (h**2)
        b = (2 * nu) / (h**2)
        c = -(nu / (h**2))

        # Upwind advection contributions to u[i], u[i-1] and u[i-2] (d)
        if order == 1:
            adv_diag, adv_sub, d = beta / h, -beta / h, 0.0
        elif order == 2:
            adv_diag, adv_sub, d = 3 / 2 * beta / h, -2 * beta / h, 1 / 2 * beta / h
        else:
            raise ValueError(f"Invalid order: {order}")
        b += adv_diag
        c += adv_sub

        A = (np.diagflat([b] * n)
             + np.diagflat([c] * (n - 1), -1)
             + np.diagflat([a] * (n - 1), 1)
             + np.diagflat([d] * (n - 2), -2))

        if x_right == BC[1]:
            # Right end is the physical boundary: solve for every node but the
            # last, which takes the value pde.f(x_right).
            rhs = np.ones((A.shape[0] - 1, 1))
            r_far = np.random.rand(1)
            r_near = np.random.rand(1)
            rhs[0] += -d * r_far - c * r_near
            rhs[1] += -d * np.random.rand(1)
            u_inner = np.linalg.solve(A[:-1, :-1], rhs)
            self.u = np.hstack((u_inner.flatten(), pde.f(x_right)))
        elif x_left == BC[0]:
            # Left end is the physical boundary: its node takes pde.f(x_left).
            rhs = np.ones((A.shape[0] - 1, 1))
            rhs[-1] += -a * np.random.rand(1)
            u_inner = np.linalg.solve(A[:-1, :-1], rhs)
            self.u = np.hstack((pde.f(x_left), u_inner.flatten()))
        else:
            # Both ends are interfaces: every node is an unknown.
            rhs = np.ones((A.shape[0], 1))
            r_far = np.random.rand(1)
            r_near = np.random.rand(1)
            rhs[0] += -d * r_far - c * r_near
            rhs[1] += -d * np.random.rand(1)
            rhs[-1] += -a * np.random.rand(1)
            self.u = np.squeeze(np.linalg.solve(A, rhs))

        self.A = A
        self.coeff = (a, b, c, d)

    def __call__(self, x):
        """Evaluate the FD solution at ``x`` by linear interpolation on the grid (mimics model(x))."""
        return np.interp(x, self.X, self.u)
        

In [None]:
### Multiple domain Schwarz coupling of PINN and/or finite difference (FD) subdomain models,
### using either strongly (SDBC) or weakly (WDBC) enforced Dirichlet boundary conditions

# Shell function used for dynamic class inheritance of PDEs
def PINN_Schwarz_Steady(model_r, model_i, X_r, X_b, alpha, pde, strong, snap=0, **kwargs):
    """Build a Schwarz subdomain solver whose class inherits from ``pde``.

    The subdomain model ``model_r`` may be a Keras PINN, which ``solve`` trains
    by gradient descent. It may also be an ``FD_1D_Steady`` model, which
    ``FD_update`` advances with a direct linear solve. Neighbouring models may be
    of either kind.

    Args:
        model_r: model of this subdomain (Keras PINN or ``FD_1D_Steady``).
        model_i: one entry per boundary point of ``X_b``. A falsy entry marks a
            physical boundary; otherwise ``entry[0]`` is the neighbouring model
            that supplies interface data. ``FD_update`` reads entries 0 (left)
            and 1 (right).
        X_r: collocation points of this subdomain.
        X_b: boundary/interface points, indexed like ``model_i``.
        alpha: weight of the residual loss; boundary and interface losses are
            weighted by ``1 - alpha``.
        pde: PDE class to inherit from. The methods here use its ``f_r``, its
            ``f``, and (weak BCs only) its ``f_b``.
        strong: if truthy, Dirichlet BCs are imposed by multiplying PINN outputs
            with ``BC_enforce``; if falsy, they are penalised in the loss.
        snap: if non-zero, ``snap - 1`` evenly spaced interior points between
            ``X_b[0]`` and ``X_b[1]`` add a data loss against ``pde.f``.
        **kwargs: forwarded to the ``pde`` constructor.

    Returns:
        An initialised ``PINN_Solver_Schwarz`` instance.
    """

    class PINN_Solver_Schwarz(pde):
        def __init__(self, model_r, model_i, X_r, X_b, alpha, strong, snap, **kwargs):

            # Initialize dynamic superclass with its default parameter signature
            super().__init__(**kwargs)

            # Store models
            self.model_r = model_r
            self.model_i = model_i

            # Store internal collocation points
            self.x = X_r

            # Store boundary points
            self.xb = X_b

            # Multiplier inside tanh(m*(1 - x)) of BC_enforce (same value the
            # other Schwarz solvers in this notebook use)
            self.m = 5

            # Store snapshot points if applicable (both endpoints excluded)
            if snap:
                self.xs = tf.constant(np.linspace(float(self.xb[0][0][0]), float(self.xb[1][0][0]), num=snap,
                                                  endpoint=False)[1:], shape=(snap-1, 1), dtype='float64')

            # Store loss scaling coefficient
            self.a = alpha

            # Latest full-data loss (PINN) / error vs. pde.f (FD), set by solve()
            self.loss = 0
            self.err = 0


        def BC_enforce(self, x):
            # Zero at x = 0 and x = 1, so multiplying a network output by this
            # factor imposes homogeneous Dirichlet values there.
            return (tf.math.tanh( self.m*(1-x) )*tf.math.tanh( x ))


        def get_residual(self, x):
            """PDE residual ``f_r(u_x, u_xx)`` of model_r at ``x`` (with BC_enforce if strong)."""

            with tf.GradientTape(persistent=True) as tape:
                # Watch variable x during this GradientTape
                tape.watch(x)

                # Compute current values u(x), with strongly enforced BCs if requested
                if strong:
                    u = self.BC_enforce(x)*self.model_r(x)
                else:
                    u = self.model_r(x)

                # First derivative is taken inside the tape so it can be differentiated again
                u_x = tape.gradient(u, x)

            # Store second derivative
            u_xx = tape.gradient(u_x, x)
            del tape

            return self.f_r(u_x, u_xx)


        def loss_strong(self, x):
            """Loss with strongly enforced BCs.

            Returns:
                (loss, phi_r, phi_i, phi_s): total loss, weighted residual loss,
                weighted interface-mismatch loss and weighted snapshot loss.
            """

            # Compute phi_r
            r = self.get_residual(x)
            phi_r = self.a * tf.reduce_mean(tf.square(r))

            # Initialize loss with residual loss function
            loss = phi_r

            phi_i = 0
            for i, model in enumerate(self.model_i):
                # Physical boundaries are already satisfied through BC_enforce
                if not model:
                    continue

                b = self.xb[i]

                # Interface mismatch: FD neighbours are used as-is, PINN
                # neighbours get the same BC_enforce factor as model_r
                u_pred1 = self.BC_enforce(b)*self.model_r(b)
                if isinstance(model[0], FD_1D_Steady):
                    u_pred2 = model[0](b)
                else:
                    u_pred2 = self.BC_enforce(b)*model[0](b)
                phi_i += (1 - self.a) * tf.reduce_mean(tf.square(u_pred1 - u_pred2))

            phi_s = 0
            if snap:
                # Snapshot data loss against pde.f.
                # NOTE(review): weighted by (1 - alpha) here but by alpha in loss_weak.
                phi_s = (1 - self.a) * tf.reduce_mean(tf.square( self.BC_enforce(self.xs)*self.model_r(self.xs)
                                                                - self.f(self.xs) ))

            # Add phi_i and phi_s to the loss
            loss += phi_i + phi_s

            return loss, phi_r, phi_i, phi_s


        def loss_weak(self, x):
            """Loss with weakly (penalty) enforced BCs.

            Returns:
                (loss, phi_r, phi_b, phi_i, phi_s): total loss and its weighted
                residual, boundary, interface and snapshot parts.
            """

            # Compute phi_r
            r = self.get_residual(x)
            phi_r = self.a * tf.reduce_mean(tf.square(r))

            # Initialize loss with residual loss function
            loss = phi_r

            phi_b = 0
            phi_i = 0
            for i, model in enumerate(self.model_i):

                b = self.xb[i]

                # Physical boundary: penalise the mismatch with pde.f_b
                if not model:
                    u_pred = self.model_r(b)
                    phi_b += (1 - self.a) * tf.reduce_mean(tf.square(self.f_b(b) - u_pred))
                    continue

                # Interface: penalise the mismatch with the neighbouring model
                u_pred1 = self.model_r(b)
                u_pred2 = model[0](b)
                phi_i += (1 - self.a) * tf.reduce_mean(tf.square(u_pred1 - u_pred2))

            phi_s = 0
            if snap:
                # Snapshot data loss against pde.f (weighted by alpha here)
                phi_s = self.a * tf.reduce_mean(tf.square( self.model_r(self.xs) - self.f(self.xs) ))

            # Add phi_b, phi_i, and phi_s to the loss
            loss += phi_b + phi_i + phi_s

            return loss, phi_r, phi_b, phi_i, phi_s


        def get_gradient(self, x):
            """Gradient of the (strong or weak) loss w.r.t. model_r's trainable variables."""
            with tf.GradientTape(persistent=True) as tape:
                # This tape is for derivatives with respect to trainable variables
                tape.watch(self.model_r.trainable_variables)
                if strong:
                    loss, _, _, _ = self.loss_strong(x)
                else:
                    loss, _, _, _, _ = self.loss_weak(x)

            g = tape.gradient(loss, self.model_r.trainable_variables)
            del tape

            return g


        def _neighbor_term(self, coef, neighbor, x_pt):
            # coef * (neighbour value at x_pt). PINN neighbours get the BC_enforce
            # factor when BCs are strong; FD neighbours are always used as-is.
            x_in = tf.reshape(x_pt, shape=(1,1))
            if strong and not isinstance(neighbor, FD_1D_Steady):
                return coef*self.BC_enforce(x_pt)*neighbor(x_in)
            return coef*neighbor(x_in)


        def _add_left_interface(self, rhs, neighbor, c, d):
            # Add the left neighbour's contributions to RHS rows 0 and 1.
            # NOTE(review): the neighbour is evaluated at this subdomain's own
            # points self.x[0] / self.x[1], not at ghost points left of self.x[0];
            # presumably intended for the overlapping grids used here - confirm.
            rhs[0] = rhs[0] + ( self._neighbor_term(-d, neighbor, self.x[0])
                                + self._neighbor_term(-c, neighbor, self.x[1]) )
            rhs[1] = rhs[1] + self._neighbor_term(-d, neighbor, self.x[1])


        def _add_right_interface(self, rhs, neighbor, a):
            # Add the right neighbour's contribution (evaluated at self.x[-1]) to the last RHS row.
            rhs[-1] = rhs[-1] + self._neighbor_term(-a, neighbor, self.x[-1])


        def FD_update(self):
            """Re-solve model_r's FD system using the current neighbour models as interface data.

            A falsy entry of ``model_i`` marks a physical boundary, whose node is
            set to ``pde.f``. The new nodal values are written to
            ``self.model_r.u``.
            """

            a, _, c, d = self.model_r.coeff
            left, right = self.model_i[0], self.model_i[1]
            A = self.model_r.A

            if left and right:
                # Interfaces on both sides: every node is an unknown
                f_NN = np.ones((A.shape[0], 1))
                self._add_left_interface(f_NN, left[0], c, d)
                self._add_right_interface(f_NN, right[0], a)
                u_FD = np.squeeze( np.linalg.solve( A, f_NN ) )

            elif right:
                # Left end is a physical boundary: its node takes pde.f(x[0])
                f_NN = np.ones((A.shape[0]-1, 1))
                self._add_right_interface(f_NN, right[0], a)
                u_FD = np.linalg.solve( A[:-1, :-1], f_NN )
                u_FD = np.hstack((self.f(self.x[0]), u_FD.flatten()))

            elif left:
                # Right end is a physical boundary: its node takes pde.f(x[-1])
                f_NN = np.ones((A.shape[0]-1, 1))
                self._add_left_interface(f_NN, left[0], c, d)
                u_FD = np.linalg.solve( A[:-1, :-1], f_NN )
                u_FD = np.hstack((u_FD.flatten(), self.f(self.x[-1])))

            else:
                # No neighbours: use pde.f at every node
                u_FD = np.squeeze(self.f(self.x))

            # Update u for current model
            self.model_r.u = u_FD


        def solve(self, optimizer, batch_size, numEpochs):
            """Advance this subdomain by one Schwarz step.

            For an FD subdomain, this runs one ``FD_update``. It then sets
            ``self.err`` to the mean squared difference from ``pde.f`` at
            ``self.x``. For a PINN subdomain, it trains for ``numEpochs`` epochs
            of mini-batches and then stores the full-data loss terms on ``self``.
            """

            @tf.function
            def train_step(x):
                # Retrieve loss gradient w.r.t. trainable variables
                grad_theta = self.get_gradient(x)

                # Perform gradient descent step
                optimizer.apply_gradients(zip(grad_theta, self.model_r.trainable_variables))

            # Split data into training batches
            train_dataset = tf.data.Dataset.from_tensor_slices((self.x,))
            train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

            # If current model is FOM, update interface boundaries with adjacent models
            if isinstance(self.model_r, FD_1D_Steady):
                self.FD_update()
                # Flatten both sides: an (n, 1) pde.f output would otherwise broadcast
                # against the (n,) FD solution into an (n, n) array and corrupt the error.
                u_fd = np.reshape(np.asarray(self.model_r.u), -1)
                u_exact = np.reshape(np.asarray(self.f(self.x)), -1)
                self.err = np.square(u_fd - u_exact).mean()
            else:
                # Iterate training
                for _ in range(numEpochs):

                    # Train on each batch
                    for (x_batch_train,) in train_dataset:
                        train_step(x_batch_train)

                    # Compute loss for full dataset to track training progress
                    if strong:
                        self.loss, self.phi_r, self.phi_i, self.phi_s = self.loss_strong(self.x)
                    else:
                        self.loss, self.phi_r, self.phi_b, self.phi_i, self.phi_s = self.loss_weak(self.x)

    # Return initialized class instance
    return PINN_Solver_Schwarz(model_r, model_i, X_r, X_b, alpha, strong, snap, **kwargs)

In [3]:
### Single domain PINN for steady state strong form PDEs

# Shell function used for dynamic class inheritance of PDEs
def PINN_Solver_Steady(model, X_r, X_b, U_b, alpha, pde, **kwargs):
    """Build a single-domain PINN solver whose class inherits from ``pde``.

    Args:
        model: Keras model approximating u(x).
        X_r: collocation points for the PDE residual.
        X_b: boundary points; entries 0 and 1 enter the loss.
        U_b: target values at the boundary points, indexed like ``X_b``.
        alpha: weight of the residual loss; the boundary loss is weighted by
            ``1 - alpha``.
        pde: PDE class to inherit from; must provide ``f_r(u_x, u_xx)``.
        **kwargs: forwarded to the ``pde`` constructor.

    Returns:
        An initialised ``PINN_Solver`` instance.
    """

    class PINN_Solver(pde):
        def __init__(self, model, X_r, X_b, U_b, alpha, **kwargs):

            # Initialize dynamic superclass with its default parameter signature
            super().__init__(**kwargs)

            # Store model
            self.model = model

            # Store internal collocation points
            self.x = X_r

            # Store boundary points
            self.xb = X_b

            # Boundary condition data supplied by the caller
            self.ub = U_b

            # Store loss scaling coefficient
            self.a = alpha

            # History of full-data losses (one entry per epoch of solve()).
            # NOTE: self.iter is initialised here but is not updated by solve().
            self.hist = []
            self.iter = 0

        def get_residual(self, x):
            """PDE residual ``f_r(u_x, u_xx)`` of the model at ``x``."""

            with tf.GradientTape(persistent=True) as tape:
                # Watch variable x during this GradientTape
                tape.watch(x)

                # Compute current values u(x)
                u = self.model(x)

                # First derivative is taken inside the tape so it can be differentiated again
                u_x = tape.gradient(u, x)

            # Store second derivative
            u_xx = tape.gradient(u_x, x)
            del tape

            return self.f_r(u_x, u_xx)

        def loss_function(self, x):
            """alpha * mean residual^2 + (1 - alpha) * mean boundary mismatch^2 (boundaries 0 and 1)."""

            # Compute phi_r
            r = self.get_residual(x)
            phi_r = tf.reduce_mean(tf.square(r))

            # Initialize loss
            loss = self.a * phi_r

            # Add phi_b to the loss
            for i in range(2):
                u_pred = self.model(self.xb[i])
                loss += (1 - self.a) * tf.reduce_mean(tf.square(self.ub[i] - u_pred))

            return loss

        def get_gradient(self, x):
            """Gradient of ``loss_function`` w.r.t. the model's trainable variables."""
            with tf.GradientTape(persistent=True) as tape:
                # This tape is for derivatives with respect to trainable variables
                tape.watch(self.model.trainable_variables)
                loss = self.loss_function(x)

            g = tape.gradient(loss, self.model.trainable_variables)
            del tape

            return g


        def solve(self, optimizer, batch_size, numEpochs, print_every=512):
            """Train the model with mini-batch gradient descent.

            After each epoch, the loss on the full collocation set is appended to
            ``self.hist``. It is also printed every ``print_every`` epochs,
            starting with epoch 0. A non-positive ``print_every`` disables printing.
            """

            @tf.function
            def train_step(x):
                # Retrieve loss gradient w.r.t. trainable variables
                grad_theta = self.get_gradient(x)

                # Perform gradient descent step
                optimizer.apply_gradients(zip(grad_theta, self.model.trainable_variables))

            # Split data into training batches
            train_dataset = tf.data.Dataset.from_tensor_slices((self.x,))
            train_dataset = train_dataset.shuffle(buffer_size=2048).batch(batch_size)

            # Iterate training
            for i in range(numEpochs):

                # Train on each batch
                for (x_batch_train,) in train_dataset:
                    train_step(x_batch_train)

                # Compute loss for full dataset to track training progress
                loss = self.loss_function(self.x)

                # Append current loss to history
                self.hist.append(loss.numpy())

                # Report progress every `print_every` epochs (default 512 = 2^9)
                if print_every > 0 and i % print_every == 0:
                    print('Epoch {:5d}: loss = {:10.8e}'.format(i,loss))

    # Return initialized class instance
    return PINN_Solver(model, X_r, X_b, U_b, alpha, **kwargs)


In [None]:
### Multiple domain Schwarz decomposition of PINNs for steady state strong form PDEs
### (Dirichlet conditions imposed weakly through penalty terms)

# Shell function used for dynamic class inheritance of PDEs
def PINN_Solver_Schwarz_Steady(model_r, model_i, X_r, X_b, alpha, pde, snap=0, **kwargs):
    """Build a weak-BC Schwarz subdomain PINN solver whose class inherits from ``pde``.

    Args:
        model_r: Keras model of this subdomain.
        model_i: neighbouring models, consumed in order, one per interface entry
            of ``X_b``.
        X_r: collocation points of this subdomain.
        X_b: iterable of ``(points, flag)`` pairs. A truthy flag marks a physical
            boundary, penalised against ``pde.f_b``. A falsy flag marks an
            interface, matched to the next model of ``model_i``.
        alpha: residual-loss weight; boundary and interface losses use
            ``1 - alpha``, the snapshot loss uses ``alpha``.
        pde: PDE class to inherit from (uses ``f_r``, ``f_b`` and, with snapshots, ``f``).
        snap: if non-zero, ``snap - 1`` evenly spaced interior snapshot points
            between the first two entries of ``X_b``.
        **kwargs: forwarded to the ``pde`` constructor.
    """

    class PINN_Solver_Schwarz(pde):
        def __init__(self, model_r, model_i, X_r, X_b, alpha, snap, **kwargs):
            super().__init__(**kwargs)

            # Subdomain model and its neighbours
            self.model_r, self.model_i = model_r, model_i

            # Collocation points and (points, is_boundary) pairs
            self.x = X_r
            self.xb = X_b

            # Interior snapshot locations (both endpoints dropped)
            if snap:
                lo = float(self.xb[0][0][0])
                hi = float(self.xb[1][0][0])
                interior = np.linspace(lo, hi, num=snap, endpoint=False)[1:]
                self.xs = tf.constant(interior, shape=(snap - 1, 1), dtype='float64')

            # Residual-loss weight
            self.a = alpha


        def get_residual(self, x):
            """Return ``f_r(u', u'')`` for the subdomain network evaluated at ``x``."""
            with tf.GradientTape(persistent=True) as g:
                g.watch(x)
                u = self.model_r(x)
                # Recorded inside the tape so that it can be differentiated once more
                du_dx = g.gradient(u, x)
            d2u_dx2 = g.gradient(du_dx, x)
            del g
            return self.f_r(du_dx, d2u_dx2)


        def loss_function(self, x):
            """Return ``(loss, phi_r, phi_b, phi_i, phi_s)`` for collocation points ``x``."""
            phi_r = self.a * tf.reduce_mean(tf.square(self.get_residual(x)))
            weight = 1 - self.a

            phi_b = 0
            phi_i = 0
            k = 0  # index of the next neighbour model
            for b, is_boundary in self.xb:
                u_here = self.model_r(b)
                if is_boundary:
                    phi_b += weight * tf.reduce_mean(tf.square(self.f_b(b) - u_here))
                else:
                    phi_i += weight * tf.reduce_mean(tf.square(u_here - self.model_i[k](b)))
                    k += 1

            phi_s = 0
            if snap:
                fit = self.model_r(self.xs) - self.f(self.xs)
                phi_s = self.a * tf.reduce_mean(tf.square(fit))

            loss = phi_r + (phi_b + phi_i + phi_s)
            return loss, phi_r, phi_b, phi_i, phi_s


        def get_gradient(self, x):
            """Gradient of the total loss with respect to ``model_r``'s trainable variables."""
            with tf.GradientTape(persistent=True) as tape:
                tape.watch(self.model_r.trainable_variables)
                total = self.loss_function(x)[0]
            # Variables are read after the forward pass, so a lazily built model is covered
            grads = tape.gradient(total, self.model_r.trainable_variables)
            del tape
            return grads


        def solve(self, optimizer, batch_size, numEpochs):
            """Train ``numEpochs`` epochs of shuffled mini-batches; store the final loss terms on self."""

            @tf.function
            def train_step(x):
                grads = self.get_gradient(x)
                optimizer.apply_gradients(zip(grads, self.model_r.trainable_variables))

            batches = (tf.data.Dataset.from_tensor_slices((self.x,))
                       .shuffle(buffer_size=1024)
                       .batch(batch_size))

            for _ in range(numEpochs):
                for (x_batch,) in batches:
                    train_step(x_batch)
                # Full-data loss terms after this epoch
                (self.loss, self.phi_r, self.phi_b,
                 self.phi_i, self.phi_s) = self.loss_function(self.x)

    return PINN_Solver_Schwarz(model_r, model_i, X_r, X_b, alpha, snap, **kwargs)

In [None]:
### Multiple domain Schwarz decomposition of PINNs with strongly enforced Dirichlet boundary conditions

# Shell function used for dynamic class inheritance of PDEs
def PINN_SDBC_Schwarz_Steady(model_r, model_i, X_r, X_b, alpha, pde, snap=0, **kwargs):
    """Build a strong-BC Schwarz subdomain PINN solver whose class inherits from ``pde``.

    The network output is multiplied by ``BC_enforce``, so physical-boundary
    entries of ``X_b`` contribute no loss term.

    Args:
        model_r: Keras model of this subdomain.
        model_i: neighbouring models, consumed in order, one per interface entry
            of ``X_b``.
        X_r: collocation points of this subdomain.
        X_b: iterable of ``(points, flag)`` pairs. A truthy flag means a physical
            boundary (skipped); a falsy flag means an interface.
        alpha: residual-loss weight; interface and snapshot losses use ``1 - alpha``.
        pde: PDE class to inherit from (uses ``f_r`` and, with snapshots, ``f``).
        snap: if non-zero, ``snap - 1`` evenly spaced interior snapshot points.
        **kwargs: forwarded to the ``pde`` constructor.
    """

    class PINN_Solver_Schwarz(pde):
        def __init__(self, model_r, model_i, X_r, X_b, alpha, snap, **kwargs):
            super().__init__(**kwargs)

            # Subdomain model and its neighbours
            self.model_r, self.model_i = model_r, model_i

            # Collocation points and (points, is_boundary) pairs
            self.x = X_r
            self.xb = X_b

            # Multiplier inside tanh(m*(1 - x)) of BC_enforce
            self.m = 5

            # Interior snapshot locations (both endpoints dropped)
            if snap:
                lo = float(self.xb[0][0][0])
                hi = float(self.xb[1][0][0])
                interior = np.linspace(lo, hi, num=snap, endpoint=False)[1:]
                self.xs = tf.constant(interior, shape=(snap - 1, 1), dtype='float64')

            # Residual-loss weight
            self.a = alpha

        def BC_enforce(self, x):
            # Zero at x = 0 and x = 1: multiplying a network output by this factor
            # imposes homogeneous Dirichlet values there
            left_factor = tf.math.tanh(x)
            right_factor = tf.math.tanh(self.m * (1 - x))
            return right_factor * left_factor

        def get_residual(self, x):
            """Return ``f_r(u', u'')`` for ``u = BC_enforce(x) * model_r(x)``."""
            with tf.GradientTape(persistent=True) as g:
                g.watch(x)
                u = self.BC_enforce(x) * self.model_r(x)
                # Recorded inside the tape so that it can be differentiated once more
                du_dx = g.gradient(u, x)
            d2u_dx2 = g.gradient(du_dx, x)
            del g
            return self.f_r(du_dx, d2u_dx2)


        def loss_function(self, x):
            """Return ``(loss, phi_r, phi_i, phi_s)`` for collocation points ``x``."""
            phi_r = self.a * tf.reduce_mean(tf.square(self.get_residual(x)))
            weight = 1 - self.a

            phi_i = 0
            k = 0  # index of the next neighbour model
            for b, is_boundary in self.xb:
                if is_boundary:
                    continue  # satisfied exactly through BC_enforce
                mask = self.BC_enforce(b)
                mismatch = mask * self.model_r(b) - mask * self.model_i[k](b)
                phi_i += weight * tf.reduce_mean(tf.square(mismatch))
                k += 1

            phi_s = 0
            if snap:
                fit = self.BC_enforce(self.xs) * self.model_r(self.xs) - self.f(self.xs)
                phi_s = weight * tf.reduce_mean(tf.square(fit))

            loss = phi_r + (phi_i + phi_s)
            return loss, phi_r, phi_i, phi_s


        def get_gradient(self, x):
            """Gradient of the total loss with respect to ``model_r``'s trainable variables."""
            with tf.GradientTape(persistent=True) as tape:
                tape.watch(self.model_r.trainable_variables)
                total = self.loss_function(x)[0]
            # Variables are read after the forward pass, so a lazily built model is covered
            grads = tape.gradient(total, self.model_r.trainable_variables)
            del tape
            return grads


        def solve(self, optimizer, batch_size, numEpochs):
            """Train ``numEpochs`` epochs of shuffled mini-batches; store the final loss terms on self."""

            @tf.function
            def train_step(x):
                grads = self.get_gradient(x)
                optimizer.apply_gradients(zip(grads, self.model_r.trainable_variables))

            batches = (tf.data.Dataset.from_tensor_slices((self.x,))
                       .shuffle(buffer_size=1024)
                       .batch(batch_size))

            for _ in range(numEpochs):
                for (x_batch,) in batches:
                    train_step(x_batch)
                # Full-data loss terms after this epoch
                self.loss, self.phi_r, self.phi_i, self.phi_s = self.loss_function(self.x)

    return PINN_Solver_Schwarz(model_r, model_i, X_r, X_b, alpha, snap, **kwargs)

In [None]:
### Multiple domain Schwarz coupling of a PINN with a finite difference model using both SDBCs and WDBCs

# Shell function used for dynamic class inheritance of PDEs
def FD_PINN_Schwarz_Steady(model_r, u_int, X_r, X_b, alpha, pde, strong, snap=0, **kwargs):
    """Build a PINN subdomain solver coupled to FD interface data; the class inherits from ``pde``.

    Args:
        model_r: Keras model of this subdomain.
        u_int: interface target value(s) supplied by the finite difference model.
        X_r: collocation points of this subdomain.
        X_b: iterable of ``(points, flag)`` pairs. A truthy flag marks a physical
            boundary; a falsy flag marks an interface matched to ``u_int``.
        alpha: residual-loss weight. Boundary and interface losses use
            ``1 - alpha``. The snapshot loss uses ``1 - alpha`` when ``strong``
            and ``alpha`` otherwise.
        pde: PDE class to inherit from (uses ``f_r``, ``f``, ``f_b``).
        strong: if truthy, BCs are imposed through ``BC_enforce``; if falsy, by penalty.
        snap: if non-zero, ``snap - 1`` evenly spaced interior snapshot points.
        **kwargs: forwarded to the ``pde`` constructor.
    """

    class PINN_Solver_Schwarz(pde):
        def __init__(self, model_r, u_int, X_r, X_b, alpha, strong, snap, **kwargs):
            super().__init__(**kwargs)

            # Subdomain model and the FD interface target
            self.model_r = model_r
            self.u = u_int

            # Collocation points and (points, is_boundary) pairs
            self.x = X_r
            self.xb = X_b

            # Multiplier inside tanh(m*(1 - x)) of BC_enforce
            self.m = 5

            # Interior snapshot locations (both endpoints dropped)
            if snap:
                lo = float(self.xb[0][0][0])
                hi = float(self.xb[1][0][0])
                interior = np.linspace(lo, hi, num=snap, endpoint=False)[1:]
                self.xs = tf.constant(interior, shape=(snap - 1, 1), dtype='float64')

            # Residual-loss weight
            self.a = alpha

        def BC_enforce(self, x):
            # Zero at x = 0 and x = 1: multiplying a network output by this factor
            # imposes homogeneous Dirichlet values there
            left_factor = tf.math.tanh(x)
            right_factor = tf.math.tanh(self.m * (1 - x))
            return right_factor * left_factor

        def get_residual(self, x):
            """Return ``f_r(u', u'')`` for the (optionally BC-enforced) network at ``x``."""
            with tf.GradientTape(persistent=True) as g:
                g.watch(x)
                u = self.model_r(x)
                if strong:
                    u = self.BC_enforce(x) * u
                # Recorded inside the tape so that it can be differentiated once more
                du_dx = g.gradient(u, x)
            d2u_dx2 = g.gradient(du_dx, x)
            del g
            return self.f_r(du_dx, d2u_dx2)


        def loss_strong(self, x):
            """Strong-BC loss: return ``(loss, phi_r, phi_i, phi_s)``."""
            phi_r = self.a * tf.reduce_mean(tf.square(self.get_residual(x)))
            weight = 1 - self.a

            # NOTE(review): every interface is compared against the same self.u;
            # with more than one interface this looks unintended - confirm u_int layout.
            phi_i = 0
            for b, is_boundary in self.xb:
                if is_boundary:
                    continue  # satisfied exactly through BC_enforce
                u_b = self.BC_enforce(b) * self.model_r(b)
                phi_i += weight * tf.reduce_mean(tf.square(u_b - self.u))

            phi_s = 0
            if snap:
                fit = self.BC_enforce(self.xs) * self.model_r(self.xs) - self.f(self.xs)
                phi_s = weight * tf.reduce_mean(tf.square(fit))

            loss = phi_r + (phi_i + phi_s)
            return loss, phi_r, phi_i, phi_s

        def loss_weak(self, x):
            """Weak-BC loss: return ``(loss, phi_r, phi_b, phi_i, phi_s)``."""
            phi_r = self.a * tf.reduce_mean(tf.square(self.get_residual(x)))
            weight = 1 - self.a

            phi_b = 0
            phi_i = 0
            for b, is_boundary in self.xb:
                u_b = self.model_r(b)
                if is_boundary:
                    phi_b += weight * tf.reduce_mean(tf.square(self.f_b(b) - u_b))
                else:
                    phi_i += weight * tf.reduce_mean(tf.square(u_b - self.u))

            phi_s = 0
            if snap:
                fit = self.model_r(self.xs) - self.f(self.xs)
                phi_s = self.a * tf.reduce_mean(tf.square(fit))

            loss = phi_r + (phi_b + phi_i + phi_s)
            return loss, phi_r, phi_b, phi_i, phi_s

        def get_gradient(self, x):
            """Gradient of the active (strong or weak) loss w.r.t. ``model_r``'s trainable variables."""
            with tf.GradientTape(persistent=True) as tape:
                tape.watch(self.model_r.trainable_variables)
                terms = self.loss_strong(x) if strong else self.loss_weak(x)
                total = terms[0]
            # Variables are read after the forward pass, so a lazily built model is covered
            grads = tape.gradient(total, self.model_r.trainable_variables)
            del tape
            return grads


        def solve(self, optimizer, batch_size, numEpochs):
            """Train ``numEpochs`` epochs of shuffled mini-batches; store the final loss terms on self."""

            @tf.function
            def train_step(x):
                grads = self.get_gradient(x)
                optimizer.apply_gradients(zip(grads, self.model_r.trainable_variables))

            batches = (tf.data.Dataset.from_tensor_slices((self.x,))
                       .shuffle(buffer_size=1024)
                       .batch(batch_size))

            for _ in range(numEpochs):
                for (x_batch,) in batches:
                    train_step(x_batch)
                # Full-data loss terms after this epoch
                if strong:
                    self.loss, self.phi_r, self.phi_i, self.phi_s = self.loss_strong(self.x)
                else:
                    (self.loss, self.phi_r, self.phi_b,
                     self.phi_i, self.phi_s) = self.loss_weak(self.x)

    return PINN_Solver_Schwarz(model_r, u_int, X_r, X_b, alpha, strong, snap, **kwargs)