In [43]:
import jax.numpy as jnp
import jax
from jax import vmap, jit, jacfwd
import optax
from functools import partial
import numpy as np
from tqdm import trange
from jax import random
import matplotlib.pyplot as plt
import sys
sys.path.insert(0, '../')

from KAN import KAN

def create_interior_points(L_range, H_range, nx, ny):
    x = np.linspace(L_range[0], L_range[1], nx)
    y = np.linspace(H_range[0], H_range[1], ny)
    return x, y

def create_boundary_points(L_range, H_range, nx, ny):
    x = np.linspace(L_range[0], L_range[1], nx)
    y = np.linspace(H_range[0], H_range[1], ny)

    x_left = jnp.array([L_range[0]])
    y_left = y

    x_top = x
    y_top = jnp.array([H_range[1]])

    x_right = jnp.array([L_range[1]])
    y_right = y

    x_bottom = x
    y_bottom = jnp.array([H_range[0]])

    return (x_left, y_left), (x_top, y_top), (x_right, y_right), (x_bottom, y_bottom)

# Define the domain
L_range = (0.0, 1.0)
H_range = (0.0, 1.0)  # Square cavity

# Create interior points using nx and ny
nx, ny = 50, 50  # Grid resolution
x_interior, y_interior = create_interior_points(L_range, H_range, nx, ny)

# Create boundary points using nx and ny
(x_left, y_left), (x_top, y_top), (x_right, y_right), (x_bottom, y_bottom) = create_boundary_points(L_range, H_range, nx, ny)


In [44]:
def plot_domain_setup(x_interior, y_interior, x_left, y_left, 
                     x_top, y_top, x_right, y_right, 
                     x_bottom, y_bottom):
    plt.figure(figsize=(6, 6))
    
    # Plot interior points
    X, Y = np.meshgrid(x_interior, y_interior)
    plt.scatter(X, Y, s=1, alpha=0.5, label='Interior')
    
    # Plot boundary points
    plt.scatter(np.full_like(y_left, x_left[0]), y_left, s=5, c='r', alpha=0.5, label='Left')
    plt.scatter(x_top, np.full_like(x_top, y_top[0]), s=5, c='g', alpha=0.5, label='Top (Lid)')
    plt.scatter(np.full_like(y_right, x_right[0]), y_right, s=5, c='b', alpha=0.5, label='Right')
    plt.scatter(x_bottom, np.full_like(x_bottom, y_bottom[0]), s=5, c='m', alpha=0.5, label='Bottom')
    
    plt.xlabel('x')
    plt.ylabel('y')
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.axis('equal')
    plt.tight_layout()
    plt.show()

# Plot sampled points
# plot_domain_setup(x_interior, y_interior, x_left, y_left, 
#                  x_top, y_top, x_right, y_right, 
#                  x_bottom, y_bottom)



In [63]:
class SF_KAN_Separable:
    def __init__(self, layer_dims, init_lr, Re = 100, k=5, r=1):
        self.input_size = layer_dims[0] # input should always be 1 for separable PINNs # TODO: add check later
        self.out_size = layer_dims[-1]
        self.r = r
        self.layer_dims = [self.input_size] + layer_dims[1:-1] + [self.r * self.out_size]
        self.model_x = KAN(layer_dims=self.layer_dims, k=k, const_spl=False, const_res=False, add_bias=True, grid_e=0.02, j='0')
        self.model_y = KAN(layer_dims=self.layer_dims, k=k, const_spl=False, const_res=False, add_bias=True, grid_e=0.02, j='0')
        
        key1, key2 = jax.random.split(jax.random.PRNGKey(10))
        self.variables_x = self.model_x.init(key1, jnp.ones([1, 1]))
        self.variables_y = self.model_y.init(key2, jnp.ones([1, 1]))
        
        self.optimizer = optax.adam(learning_rate=init_lr, nesterov=True)
        self.opt_state_x = self.optimizer.init(self.variables_x['params'])
        self.opt_state_y = self.optimizer.init(self.variables_y['params'])

        self.train_losses = []

    def interpolate_moments(self, mu_old, nu_old, new_shape):
        old_shape = mu_old.shape
        size = old_shape[0]
        old_j = old_shape[1]
        new_j = new_shape[1]
        
        old_indices = jnp.linspace(0, old_j - 1, old_j)
        new_indices = jnp.linspace(0, old_j - 1, new_j)

        interpolate_fn = lambda old_row: jnp.interp(new_indices, old_indices, old_row)

        mu_new = vmap(interpolate_fn)(mu_old)
        nu_new = vmap(interpolate_fn)(nu_old)
        
        return mu_new, nu_new

    def smooth_state_transition(self, old_state, params):
        adam_count = old_state[0].count
        adam_mu, adam_nu = old_state[0].mu, old_state[0].nu

        layer_keys = {k for k in adam_mu.keys() if k.startswith('layers_')}
        
        for key in layer_keys:
            c_shape = params[key]['c_basis'].shape
            mu_new0, nu_new0 = self.interpolate_moments(adam_mu[key]['c_basis'], adam_nu[key]['c_basis'], c_shape)
            adam_mu[key]['c_basis'], adam_nu[key]['c_basis'] = mu_new0, nu_new0

        adam_state = optax.ScaleByAdamState(adam_count, adam_mu, adam_nu)
        extra_state = optax.ScaleByScheduleState(adam_count)
        new_state = (adam_state, extra_state)

        return new_state

    def predict(self, x, y):
        variables_x, variables_y = self.variables_x, self.variables_y
        preds, _ = self.forward_pass(variables_x, variables_y, x, y)
        return preds


    @partial(jit, static_argnums=(0,))
    def forward_pass(self, variables_x, variables_y, x, y):
        preds_x, spl_regs_x = self.model_x.apply(variables_x, x[:, None])
        preds_y, spl_regs_y = self.model_y.apply(variables_y, y[:, None])
        
        # print(f"preds_x shape = {preds_x.shape}")
        # print(f"preds_y shape = {preds_y.shape}")

        preds_x = preds_x.reshape(-1, self.out_size, self.r)
        preds_y = preds_y.reshape(-1, self.out_size, self.r)
        # preds = jnp.einsum('ir,jr->ijr', preds_x, preds_y)
        preds = jnp.einsum('ijk,ljk->ilj', preds_x, preds_y)
        # print(f"preds shape = {preds.shape}")

        spl_regs = spl_regs_x + spl_regs_y
        
        return preds, spl_regs

    @partial(jit, static_argnums=(0,))
    def loss(self, params_x, params_y, state_x, state_y, *args):
        variables_x = {'params': params_x, 'state': state_x}
        variables_y = {'params': params_y, 'state': state_y}
        return self.loss_fn(variables_x, variables_y, *args)

    @partial(jit, static_argnums=(0,))
    def train_step(self, params_x, params_y, state_x, state_y, opt_state_x, opt_state_y, *args):
        (loss_value, (physics_loss, boundary_loss)), grads = jax.value_and_grad(self.loss, has_aux=True, argnums=(0,1))(
            params_x, params_y, state_x, state_y, *args
        )
        grads_x, grads_y = grads

        updates_x, opt_state_x = self.optimizer.update(grads_x, opt_state_x)
        updates_y, opt_state_y = self.optimizer.update(grads_y, opt_state_y)

        params_x = optax.apply_updates(params_x, updates_x)
        params_y = optax.apply_updates(params_y, updates_y)

        return params_x, params_y, opt_state_x, opt_state_y, loss_value, physics_loss, boundary_loss

    def train(self, num_epochs, *args):
        params_x, state_x = self.variables_x['params'], self.variables_x['state']
        params_y, state_y = self.variables_y['params'], self.variables_y['state']
        opt_state_x, opt_state_y = self.opt_state_x, self.opt_state_y
        loss_history = []

        pbar = trange(num_epochs, smoothing=0.)
        for epoch in pbar:                
            params_x, params_y, opt_state_x, opt_state_y, loss_value, physics_loss, boundary_loss = self.train_step(
                params_x, params_y, state_x, state_y, opt_state_x, opt_state_y, *args
            )
            loss_history.append(loss_value)

            if epoch % 10 == 0:
                pbar.set_postfix({
                    'Total Loss': f"{loss_value:.4e}",
                    'Physics Loss': f"{physics_loss:.4e}",
                    'Boundary Loss': f"{boundary_loss:.4e}"
                })
        
        self.variables_x = {'params': params_x, 'state': state_x}
        self.variables_y = {'params': params_y, 'state': state_y}
        return loss_history


class Cavity_SF_KAN_Separable(SF_KAN_Separable):
    def __init__(self, *args, Re=100.0, r=10, **kwargs):
        super().__init__(*args, **kwargs)
        self.Re = Re

    @partial(jit, static_argnums=(0,))
    def loss_fn(self, variables_x, variables_y, x_interior, y_interior, x_left, y_left, x_top, y_top, x_right, y_right, x_bottom, y_bottom):
        # Compute physics residuals
        residuals = self.compute_residuals(variables_x, variables_y, x_interior, y_interior)
        physics_loss = jnp.mean(jnp.square(residuals))

        # Boundary losses
        preds_left, _ = self.forward_pass(variables_x, variables_y, x_left, y_left)
        preds_top, _ = self.forward_pass(variables_x, variables_y, x_top, y_top)
        preds_right, _ = self.forward_pass(variables_x, variables_y, x_right, y_right)
        preds_bottom, _ = self.forward_pass(variables_x, variables_y, x_bottom, y_bottom)

        left_loss = jnp.mean(jnp.square(preds_left[..., 0]) + jnp.square(preds_left[..., 1]))
        top_loss = jnp.mean(jnp.square(preds_top[..., 0] - 1.0) + jnp.square(preds_top[..., 1]))
        right_loss = jnp.mean(jnp.square(preds_right[..., 0]) + jnp.square(preds_right[..., 1]))
        bottom_loss = jnp.mean(jnp.square(preds_bottom[..., 0]) + jnp.square(preds_bottom[..., 1]))

        boundary_loss = left_loss + top_loss + right_loss + bottom_loss

        # Total loss
        total_loss = physics_loss + boundary_loss

        return total_loss, (physics_loss, boundary_loss)

    @partial(jit, static_argnums=(0,))
    def compute_residuals(self, variables_x, variables_y, x_interior, y_interior):
        def model_x_func(x):
            # print(f"Shape of x = {x.shape}")
            x_feat = self.model_x.apply(variables_x, x.reshape(-1, 1))[0]
            # print(f"Shape of x_feat = {x_feat.shape}")
            x_feat = x_feat.reshape(self.out_size, self.r) 
            # print(f"Shape of x_feat after reshape = {x_feat.shape}")
            return x_feat

        def model_y_func(y):
            y_feat = self.model_y.apply(variables_y, y.reshape(-1, 1))[0]
            y_feat = y_feat.reshape(self.out_size, self.r)
            return y_feat

        def model_x_grad(x):
            return jacfwd(model_x_func)(x)

        def model_y_grad(y):
            return jacfwd(model_y_func)(y)

        def model_x_hess(x):
            return jacfwd(jacfwd(model_x_func))(x)

        def model_y_hess(y):
            return jacfwd(jacfwd(model_y_func))(y)

        x_feats = vmap(model_x_func)(x_interior)
        y_feats = vmap(model_y_func)(y_interior)
        x_grads = vmap(model_x_grad)(x_interior)
        y_grads = vmap(model_y_grad)(y_interior)
        x_hess = vmap(model_x_hess)(x_interior)
        y_hess = vmap(model_y_hess)(y_interior)

        u_x, v_x, p_x = x_feats[:, 0, :], x_feats[:, 1, :], x_feats[:, 2, :]
        u_y, v_y, p_y = y_feats[:, 0, :], y_feats[:, 1, :], y_feats[:, 2, :]

        du_x_dx, dv_x_dx, dp_x_dx = x_grads[:, 0, :], x_grads[:, 1, :], x_grads[:, 2, :]
        du_y_dy, dv_y_dy, dp_y_dy = y_grads[:, 0, :], y_grads[:, 1, :], y_grads[:, 2, :]

        d2u_x_dx2, d2v_x_dx2 = x_hess[:, 0, :], x_hess[:, 1, :]
        d2u_y_dy2, d2v_y_dy2 = y_hess[:, 0, :], y_hess[:, 1, :]

        u = jnp.einsum('ir,jr->ij', u_x, u_y)
        v = jnp.einsum('ir,jr->ij', v_x, v_y)
        p = jnp.einsum('ir,jr->ij', p_x, p_y)

        du_dx = jnp.einsum('ir,jr->ij', du_x_dx, u_y)
        du_dy = jnp.einsum('ir,jr->ij', u_x, du_y_dy)
        dv_dx = jnp.einsum('ir,jr->ij', dv_x_dx, v_y)
        dv_dy = jnp.einsum('ir,jr->ij', v_x, dv_y_dy)
        dp_dx = jnp.einsum('ir,jr->ij', dp_x_dx, p_y)
        dp_dy = jnp.einsum('ir,jr->ij', p_x, dp_y_dy)

        d2u_dx2 = jnp.einsum('ir,jr->ij', d2u_x_dx2, u_y)
        d2u_dy2 = jnp.einsum('ir,jr->ij', u_x, d2u_y_dy2)
        d2v_dx2 = jnp.einsum('ir,jr->ij', d2v_x_dx2, v_y)
        d2v_dy2 = jnp.einsum('ir,jr->ij', v_x, d2v_y_dy2)

        continuity = du_dx + dv_dy
        momentum_x = u * du_dx + v * du_dy + dp_dx - (1/self.Re) * (d2u_dx2 + d2u_dy2)
        momentum_y = u * dv_dx + v * dv_dy + dp_dy - (1/self.Re) * (d2v_dx2 + d2v_dy2)

        return jnp.stack([continuity, momentum_x, momentum_y], axis=-1)


In [None]:
# Model parameters
layer_dims = [1, 5, 5, 3]  # Input dim is always 1, output dim is 3 (u, v, p)
init_lr = 1e-3
Re = 100.0
k=5
r=5


model = Cavity_SF_KAN_Separable(
    layer_dims=layer_dims,
    init_lr=init_lr,
    Re=Re,
    k=k,
    r=r
)

num_epochs = 50000
loss_history = model.train(
    num_epochs,
    x_interior, y_interior,
    x_left, y_left,
    x_top, y_top,
    x_right, y_right,
    x_bottom, y_bottom
)



100%|██████████| 50000/50000 [02:17<00:00, 362.80it/s, Total Loss=6.8732e-02, Physics Loss=2.9686e-02, Boundary Loss=3.9046e-02]


In [None]:
# Plot loss history
def plot_loss_history(loss_history):
    plt.figure(figsize=(6, 5))
    epochs = range(0, len(loss_history), 100)
    plt.plot(epochs, loss_history[::100])
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')  # Use log scale for y-axis
    plt.grid(True)
    plt.show()

plot_loss_history(loss_history)


In [66]:
# Predictions for plotting
nx_plot, ny_plot = 100, 100
x_plot = jnp.linspace(L_range[0], L_range[1], nx_plot)
y_plot = jnp.linspace(L_range[0], L_range[1], ny_plot)

X_mesh, Y_mesh = jnp.meshgrid(x_plot, y_plot)

# Make predictions
uvp_pred = model.predict(x_plot, y_plot)

# Extract u, v, p 
U = np.array(uvp_pred[:, :, 0]).T
V = np.array(uvp_pred[:, :, 1]).T
P = np.array(uvp_pred[:, :, 2]).T

vmag = np.sqrt(U**2 + V**2)


output_data = {
    'mesh': {
        'x_mesh': X_mesh,
        'y_mesh': Y_mesh,
        'L_range': L_range
    },
    'field_variables': {
        'u': U,
        'v': V,
        'vmag': vmag,  # velocity magnitude
        'p': P
    },
    'parameters': {
        'Re': Re
    },
    'training': {
        'loss_history': loss_history
    }
}

# Save data
np.save(f'./data/2d_ns_spikan_Re_{Re}_nx{nx}_ny{ny}_epochs{num_epochs}_{layer_dims}.npy', output_data)


In [None]:
# Plotting
fig1, axs1 = plt.subplots(1, 4, figsize=(20, 5))  # 1 row, 4 columns

im00 = axs1[0].contourf(X_mesh, Y_mesh, U, levels=50, cmap='RdBu_r')
axs1[0].set_title(r'Predicted $u$, SPIKAN')
plt.colorbar(im00, ax=axs1[0])

im01 = axs1[1].contourf(X_mesh, Y_mesh, V, levels=50, cmap='RdBu_r')
axs1[1].set_title(r'Predicted $v$, SPIKAN')
plt.colorbar(im01, ax=axs1[1])

im02 = axs1[2].contourf(X_mesh, Y_mesh, vmag, levels=50, cmap='RdBu_r')
axs1[2].set_title(r'Predicted $v_{mag}$, SPIKAN')
plt.colorbar(im02, ax=axs1[2])

im03 = axs1[3].contourf(X_mesh, Y_mesh, P, levels=50, cmap='RdBu_r')
axs1[3].set_title(r'Predicted $p$, SPIKAN')
plt.colorbar(im03, ax=axs1[3])

for ax in axs1:
    ax.set_aspect('equal')
    ax.set_xlim(L_range[0], L_range[1])
    ax.set_ylim(L_range[0], L_range[1])
    ax.set_xlabel('x')
    ax.set_ylabel('y')

plt.tight_layout()
plt.show()

In [29]:
def count_trainable_params(variables_x, variables_y):
    # Flatten the parameter trees
    flat_params_x, _ = jax.tree_util.tree_flatten(variables_x['params'])
    flat_params_y, _ = jax.tree_util.tree_flatten(variables_y['params'])
    
    # Sum the number of elements in each parameter array for both models
    total_params_x = sum(p.size for p in flat_params_x)
    total_params_y = sum(p.size for p in flat_params_y)
    
    # Total trainable parameters in the entire separable model
    total_params = total_params_x + total_params_y
    return total_params

# Assuming `variables_x` and `variables_y` have been initialized in your model
num_params = count_trainable_params(model.variables_x, model.variables_y)
print(f"Number of trainable parameters: {num_params}")

Number of trainable parameters: 1704
