In [56]:
from kan import KAN, LBFGS
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch import autograd
from tqdm import tqdm
import time

In [57]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

In [63]:
def loss_fun(x_int, x_bc, xdim, model, mu=1.0):
    """Physics-informed loss for the PDE  u_t + Δu - mu * |∇u|^2 = 0.

    Two terms are returned:

        interior:  mean over x_int of (u_t + sum_j d²u/dx_j² - mu * sum_j (du/dx_j)²)²
        boundary:  mean over x_bc of (u(x, t) - log((1 + |x|²) / 2))²

    where |x|² is the squared Euclidean norm of each point's spatial part
    (get_data() samples the boundary points at t = T).

    All derivatives come from autograd with ``create_graph=True``, so every term
    of the residual back-propagates into the model parameters. They are exact
    (up to float round-off), unlike a float32 finite difference with h = 1e-5,
    whose second-derivative round-off error scales like eps / h².

    Args:
        x_int: (N, xdim + 1) collocation points; the last column is time.
        x_bc: (M, xdim + 1) boundary points; the last column is time.
        xdim: number of spatial dimensions (columns before the time column).
        model: callable mapping a (K, xdim + 1) tensor to a (K, 1) tensor.
        mu: coefficient of the gradient-squared term (default 1.0, as before).

    Returns:
        (R_int, R_bc): scalar tensors with the mean squared interior residual and
        the mean squared boundary mismatch.

    Raises:
        ValueError: if ``x_int`` does not have exactly ``xdim`` spatial columns.
    """
    # Fresh leaf tensors so autograd can differentiate w.r.t. space and time
    # independently without touching the caller's x_int.
    x = x_int[:, :-1].detach().clone().requires_grad_(True)
    t = x_int[:, -1:].detach().clone().requires_grad_(True)
    if x.shape[1] != xdim:
        raise ValueError(f"expected {xdim} spatial columns, got {x.shape[1]}")

    u = model(torch.cat((x, t), dim=1))
    ones = torch.ones_like(u)

    # create_graph=True keeps the derivatives differentiable: w.r.t. the model
    # parameters (for the loss) and, for du_dx, w.r.t. x (for the Laplacian).
    du_dt = torch.autograd.grad(u, t, grad_outputs=ones, create_graph=True)[0]
    du_dx = torch.autograd.grad(u, x, grad_outputs=ones, create_graph=True)[0]

    # Laplacian = sum of the diagonal second derivatives, one autograd pass per
    # spatial dimension; works for any batch size.
    laplacian = torch.zeros_like(du_dt.squeeze(1))
    for j in range(xdim):
        du_dxj = du_dx[:, j]
        d2u_dxj2 = torch.autograd.grad(
            du_dxj, x, grad_outputs=torch.ones_like(du_dxj), create_graph=True
        )[0][:, j]
        laplacian = laplacian + d2u_dxj2

    residual = du_dt.squeeze(1) + laplacian - mu * torch.sum(du_dx ** 2, dim=1)
    R_int = torch.mean(residual ** 2)

    # Boundary term: the target uses the squared norm of *each* point (row-wise),
    # not the norm of the whole batch matrix.
    u_bc = model(x_bc)
    sq_norm_bc = torch.sum(x_bc[:, :-1] ** 2, dim=1, keepdim=True)
    target_bc = torch.log((1 + sq_norm_bc) / 2)
    R_bc = torch.mean(torch.square(u_bc - target_bc))

    return R_int, R_bc



# def loss_fun(x_int, x_bc, xdim, model):
#     mu = 1
#     # print(model)
#     x, t = x_int[:, :-1], x_int[:, -1:]
#     # print("Shape of x: " + str(x.shape) + ", Shape of t: " + str(t.shape))
#     t.requires_grad_()
#     x.requires_grad_()

#     input_tensor = torch.cat((x, t), dim=1)
#     # print("Shape of input tensor: " + str(input_tensor.shape))
#     u = model(input_tensor)
#     # print("Shape of output tensor: " + str(u.shape))

#     du_dt = torch.zeros_like(u)

#     # Loop through each of the 50 values
#     for i in range(50):
#         if i == 0:
#             # Forward difference for the first element
#             delta_t = t[i + 1] - t[i]
#             du_dt[i] = (u[i + 1] - u[i]) / delta_t
#         elif i == 49:
#             # Backward difference for the last element
#             delta_t = t[i] - t[i - 1]
#             du_dt[i] = (u[i] - u[i - 1]) / delta_t
#         else:
#             # Central difference for the interior elements
#             delta_t_forward = t[i + 1] - t[i]
#             delta_t_backward = t[i] - t[i - 1]
#             du_dt[i] = (u[i + 1] - u[i - 1]) / (delta_t_forward + delta_t_backward)
#     # print("Euler's du_dt: " + str(du_dt[:10]))

#     # du_dt = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(t), create_graph=True)[0]
#     # print("Autograd's du_dt: " + str(du_dt[:10]))

#     # Display the first-order derivative array
#     # print("First-order derivative of u with respect to t:")
#     # print(du_dt)

#     # First-order time derivative using finite differences
#     # du_dt = torch.zeros_like(u)

#     # for i in range(50):
#     #     if i == 0:
#     #         # Forward difference for the first element
#     #         delta_t = t[i + 1] - t[i]
#     #         du_dt[i] = (u[i + 1] - u[i]) / delta_t
#     #     elif i == 49:
#     #         # Backward difference for the last element
#     #         delta_t = t[i] - t[i - 1]
#     #         du_dt[i] = (u[i] - u[i - 1]) / delta_t
#     #     else:
#     #         # Central difference for the interior elements
#     #         delta_t_forward = t[i + 1] - t[i]
#     #         delta_t_backward = t[i] - t[i - 1]
#     #         du_dt[i] = (u[i + 1] - u[i - 1]) / (delta_t_forward + delta_t_backward)
#     # print("Shape of du_dt: " + str(du_dt.shape))
#     # print(du_dt)

#     # First-order spatial derivative using finite differences
#     # du_dx = torch.zeros_like(u)

#     # for i in range(50):
#     #     if i == 0:
#     #         # Forward difference for the first element
#     #         delta_x = x[i + 1] - x[i]
#     #         du_dx[i] = (u[i + 1] - u[i]) / delta_x
#     #     elif i == 49:
#     #         # Backward difference for the last element
#     #         delta_x = x[i] - x[i - 1]
#     #         du_dx[i] = (u[i] - u[i - 1]) / delta_x
#     #     else:
#     #         # Central difference for the interior elements
#     #         delta_x_forward = x[i + 1] - x[i]
#     #         delta_x_backward = x[i] - x[i - 1]
#     #         du_dx[i] = (u[i + 1] - u[i - 1]) / (delta_x_forward + delta_x_backward)
#     # print("Shape of du_dx: " + str(du_dx.shape))

#     du_dx = torch.zeros_like(x)

#     # Loop through each of the 50 values
#     for i in range(50):
#         for j in range(3):  # Loop through each dimension of x
#             if i == 0:
#                 # Forward difference for the first element
#                 delta_x = x[i + 1, j] - x[i, j]
#                 du_dx[i, j] = (u[i + 1] - u[i]) / delta_x
#             elif i == 49:
#                 # Backward difference for the last element
#                 delta_x = x[i, j] - x[i - 1, j]
#                 du_dx[i, j] = (u[i] - u[i - 1]) / delta_x
#             else:
#                 # Central difference for the interior elements
#                 delta_x_forward = x[i + 1, j] - x[i, j]
#                 delta_x_backward = x[i, j] - x[i - 1, j]
#                 du_dx[i, j] = (u[i + 1] - u[i - 1]) / (delta_x_forward + delta_x_backward)


#     d2u_dx2 = torch.zeros_like(x)

#     # Loop through each dimension of x (3 dimensions in this case)
#     for j in range(3):
#         # Loop through each of the 50 values
#         for i in range(50):
#             if i == 0:
#                 # Forward difference for the first element
#                 delta_x = x[i + 1, j] - x[i, j]
#                 d2u_dx2[i, j] = (u[i + 2] - 2 * u[i + 1] + u[i]) / (delta_x ** 2)
#             elif i == 49:
#                 # Backward difference for the last element
#                 delta_x = x[i, j] - x[i - 1, j]
#                 d2u_dx2[i, j] = (u[i] - 2 * u[i - 1] + u[i - 2]) / (delta_x ** 2)
#             else:
#                 # Central difference for the interior elements
#                 delta_x_forward = x[i + 1, j] - x[i, j]
#                 delta_x_backward = x[i, j] - x[i - 1, j]
#                 delta_x = (delta_x_forward + delta_x_backward) / 2
                
#                 u_plus_h = u[i + 1]
#                 u_current = u[i]
#                 u_minus_h = u[i - 1]
                
#                 d2u_dx2[i, j] = (u_plus_h - 2 * u_current + u_minus_h) / (delta_x ** 2)

#     # Display the second-order derivative tensor
#     # print("Second-order derivative of u with respect to each dimension of x:")
#     # print(d2u_dx2)

#     # # Second-order spatial derivative using finite differences
#     # d2u_dx2 = torch.zeros_like(u)
#     # print("Printing all values for inspection:")
#     # print("u[0]: " + str(u[0]))
#     # print("u[1]: " + str(u[1]))
#     # print("u[2]: " + str(u[2]))
#     # print("x[0]: " + str(x[0]))
#     # print("x[1]: " + str(x[1]))
#     # print("delta_x: " + str(x[1] - x[0]))
#     # print()
#     # # Loop through each of the 50 values
#     # for i in range(50):
#     #     if i == 0:
#     #         # Forward difference for the first element
#     #         print("Before tensor operation")
#     #         delta_x = x[i + 1] - x[i]
#     #         print("After tensor operation")
#     #         d2u_dx2[i] = (u[i + 2] - 2 * u[i + 1] + u[i]) / (delta_x ** 2)
#     #         print("After tensor operation 1")
#     #     elif i == 49:
#     #         # Backward difference for the last element
#     #         delta_x = x[i] - x[i - 1]
#     #         print("After tensor operation 2")
#     #         d2u_dx2[i] = (u[i] - 2 * u[i - 1] + u[i - 2]) / (delta_x ** 2)
#     #         print("After tensor operation 3")
#     #     else:
#     #         # Central difference for the interior elements
#     #         delta_x_forward = x[i + 1] - x[i]
#     #         print("After tensor operation 4")
#     #         delta_x_backward = x[i] - x[i - 1]
#     #         print("After tensor operation 5")
#     #         delta_x = (delta_x_forward + delta_x_backward) / 2
#     #         print("After tensor operation 6")
            
#     #         u_plus_h = u[i + 1]
#     #         u_current = u[i]
#     #         u_minus_h = u[i - 1]
#     #         print("After tensor operation 7")
#     #         d2u_dx2[i] = (u_plus_h - 2 * u_current + u_minus_h) / (delta_x ** 2)
#     #         print("After tensor operation 8")

#     # # Display the second-order derivative tensor
#     # print("Second-order derivative of u with respect to x:")
#     # print(d2u_dx2)
#     # print("Shape of d2u_dx2: " + str(d2u_dx2.shape))

#     # Compute the residual R_int
#     R_int = torch.mean((du_dt.squeeze(1) + torch.sum(d2u_dx2, dim=1) - mu*torch.sum(du_dx**2, dim=1))**2)
#     # R_int = torch.mean((du_dt.squeeze(1) + d2u_dx2.squeeze(1) - mu * (du_dx.squeeze(1) ** 2)) ** 2)

#     # Boundary condition handling
#     x_bc, t_bc = x_bc[:, :-1], x_bc[:, -1:]
#     t_bc.requires_grad_()
#     input_tensor_bc = torch.cat((x_bc, t_bc), dim=1)
#     u_bc = model(input_tensor_bc)

#     R_bc = torch.mean(torch.square(u_bc - torch.log((1 + torch.norm(x_bc, p=2) ** 2) / 2)))

#     return R_int, R_bc



# # def loss_fun(x_int, x_bc, xdim, model):

# #     mu = 1

# #     x, t = x_int[:, :-1], x_int[:, -1:]
# #     print("Shape of x: " + str(x.shape) + ", Shape of t: " + str(t.shape))
# #     t.requires_grad_()
# #     x.requires_grad_()

# #     input_tensor = torch.cat((x, t), dim=1)
# #     print("Shape of input tensor: " + str(input_tensor.shape))
# #     u = model(input_tensor)
# #     print("Shape of output tensor: " + str(u.shape))


# #     du_dt = torch.zeros_like(u)

# #     # Loop through each of the 50 values
# #     for i in range(50):
# #         if i == 0:
# #             # Forward difference for the first element
# #             delta_t = t[i + 1] - t[i]
# #             du_dt[i] = (u[i + 1] - u[i]) / delta_t
# #         elif i == 49:
# #             # Backward difference for the last element
# #             delta_t = t[i] - t[i - 1]
# #             du_dt[i] = (u[i] - u[i - 1]) / delta_t
# #         else:
# #             # Central difference for the interior elements
# #             delta_t_forward = t[i + 1] - t[i]
# #             delta_t_backward = t[i] - t[i - 1]
# #             du_dt[i] = (u[i + 1] - u[i - 1]) / (delta_t_forward + delta_t_backward)
# #     # print("Shape of t: " + str(t.shape))
# #     # du_dt = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(t), create_graph=True)[0]
# #     print("Shape of du_dt: " + str(du_dt.shape))


# #     du_dx = torch.zeros_like(u)

# #     # Loop through each of the 50 values
# #     for i in range(50):
# #         if i == 0:
# #             # Forward difference for the first element
# #             delta_x = x[i + 1] - x[i]
# #             du_dx[i] = (u[i + 1] - u[i]) / delta_x
# #         elif i == 49:
# #             # Backward difference for the last element
# #             delta_x = x[i] - x[i - 1]
# #             du_dx[i] = (u[i] - u[i - 1]) / delta_x
# #         else:
# #             # Central difference for the interior elements
# #             delta_x_forward = x[i + 1] - x[i]
# #             delta_x_backward = x[i] - x[i - 1]
# #             du_dx[i] = (u[i + 1] - u[i - 1]) / (delta_x_forward + delta_x_backward)


# #     # print("Shape of x: " + str(x.shape))
# #     # du_dx = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), retain_graph=True, allow_unused=True)[0]
# #     print("Shape of du_dx: " + str(du_dx.shape))

# #     # for i in range(50):
# #     #     # Compute the gradient of u[i] with respect to x[i]
# #     #         du_dx = torch.autograd.grad(u[i], x[i][j], grad_outputs=torch.ones_like(u[i]), create_graph=True)[0]
        
# #     #         # Check if the gradient is None (if the tensor wasn't used in the computation)
# #     #         if du_dx is not None:
# #     #             print("Shape of du_dx: " + str(du_dx.shape))
# #     #             # Sum the gradients to get a single scalar value
# #     #             grad_sum += du_dx[i]
# #     #         else:
# #     #             print("du_dx was found to be None.")

# #     # print(du_dx)
# #     # du_dx = torch.autograd.grad(u, x, create_graph=True)[0].sum()
# #     # print("Shape of du_dx: " + str(du_dx.shape))

# #     d2u_dx2 = torch.zeros_like(u)

# #     # Loop through each of the 50 values
# #     for i in range(50):
# #         if i == 0:
# #             # Forward difference for the first element
# #             delta_x_forward = x[i + 1] - x[i]
# #             u_plus_h = u[i + 1]
# #             u_current = u[i]
# #             d2u_dx2[i] = (u_plus_h - 2 * u_current + u_current) / (delta_x_forward ** 2)
# #         elif i == 49:
# #             # Backward difference for the last element
# #             delta_x_backward = x[i] - x[i - 1]
# #             u_minus_h = u[i - 1]
# #             u_current = u[i]
# #             d2u_dx2[i] = (u_current - 2 * u_current + u_minus_h) / (delta_x_backward ** 2)
# #         else:
# #             # Central difference for the interior elements
# #             delta_x_forward = x[i + 1] - x[i]
# #             delta_x_backward = x[i] - x[i - 1]
# #             delta_x = (delta_x_forward + delta_x_backward) / 2
            
# #             u_plus_h = u[i + 1]
# #             u_current = u[i]
# #             u_minus_h = u[i - 1]
            
# #             d2u_dx2[i] = (u_plus_h - 2 * u_current + u_minus_h) / (delta_x ** 2)

# #     # print(d2u_dx2.shape)
# #     # print(d2u_dx2)

# #     # print("Jacobian Test start")
# #     # d2u_dx2 = torch.autograd.functional.jacobian(lambda x: du_dx, x)
# #     # print("Jacobian Test end")

# #     print("Shape of d2u_dx2: ", d2u_dx2.shape)

# #     # print("du_dx gradient enabled.")
# #     # du_dx.requires_grad_()
# #     # print(du_dx.requires_grad)
# #     # print("du_dx grad_fn: ", du_dx.grad_fn)
# #     # print("x grad_fn: ",x.grad_fn)
# #     # d2u_dx2 = []
# #     # for i in range(xdim):
# #     #     # (batch_size, 1)
# #     #     # print("Gradient calculation start")
# #     #     # temp = du_dx[:, i].sum()
# #     #     # print("Input Grad:", temp.requires_grad)
# #     #     d2u_dxidxi = torch.autograd.grad(du_dx[:, i].sum(), x, grad_outputs=torch.ones_like(du_dx[:, i]), retain_graph=True, allow_unused=True)[0]
# #     #     # [:, i:i+1]
# #     #     print(d2u_dxidxi)
# #     #     print("First gradient calculated.")
# #     #     d2u_dx2.append(d2u_dxidxi)
# #     # # (batch_size, x_dim)
# #     # d2u_dx2 = torch.concat(d2u_dx2, dim=1)

# #     # d2u_dx2 = []
# #     # for i in range(xdim):
# #     #     d2u_dxidxi = torch.autograd.grad(du_dx[:, i].sum(), x, retain_graph=True, allow_unused=True)[0][:, i:i+1]
# #     #     if d2u_dxidxi is None:
# #     #         raise RuntimeError(f'Gradient w.r.t x[:, {i}] is None')
# #     #     d2u_dx2.append(d2u_dxidxi)

# #     # d2u_dx2 = torch.cat(d2u_dx2, dim=1)

# #     R_int = torch.mean((du_dt.squeeze(1) + torch.sum(d2u_dx2, dim=1) - mu*torch.sum(du_dx**2, dim=1))**2)
# #     # R_int = torch.mean((du_dt.squeeze(1) - mu*torch.sum(du_dx**2, dim=1))**2)

# #     x, t = x_bc[:, :-1], x_bc[:, -1:]
# #     t.requires_grad_()
# #     input_tensor = torch.cat((x, t), dim=1)
# #     u_bc = model(input_tensor)   

# #     R_bc =  torch.mean(torch.square(u_bc - torch.log((1 + torch.norm(x, p=2) ** 2) / 2)))

# #     return R_int, R_bc

In [59]:
def get_data(n_samples=50):
    """Sample one batch of interior and terminal-time training points.

    Side effect: (re)defines the module-level globals ``xdim`` (= 3) and
    ``T`` (= 1), which other cells read.

    Args:
        n_samples: number of interior points and, separately, number of
            boundary points to draw (default 50, the previous fixed size).

    Returns:
        (x_int, x_bc): tensors of shape (n_samples, xdim + 1) on ``device``,
        with time in the last column.
        - x_int: x ~ N(0, I) (torch.randn), t ~ U[0, T) (torch.rand * T).
        - x_bc:  x ~ N(0, I), t = T for every point.
    """
    global xdim, T

    xdim = 3
    T = 1

    # Interior collocation points (RNG call order unchanged: randn, rand, randn).
    x_space = torch.randn((n_samples, xdim), device=device)
    t_interior = torch.rand((n_samples, 1), device=device) * T
    x_int = torch.cat([x_space, t_interior], dim=1)

    # Terminal-time points, where the boundary target is imposed.
    x_space_bc = torch.randn((n_samples, xdim), device=device)
    t_terminal = torch.ones((n_samples, 1), device=device) * T
    x_bc = torch.cat([x_space_bc, t_terminal], dim=1)

    return x_int, x_bc

In [60]:
# Loss weights: lambda_b scales the boundary loss in train().
# NOTE(review): lambda_ic is not referenced in the visible cells.
lambda_b = 10.0
lambda_ic = 10.0

steps = 40   # number of outer LBFGS steps run by train()
alpha = 0.1  # NOTE(review): not referenced in the visible cells
log = 1      # progress-bar refresh interval, in epochs

N = 50       # sample count; matches the 50 points get_data() draws per batch
xdim = 3     # spatial dimension; the model input is (x, t), i.e. xdim + 1 features

# The former module-level `global loss_int_hist, ...` statement was removed:
# at module scope `global` is a no-op and did not make train()'s histories global.
pred_hist = np.zeros(N)

# Network input width xdim + 1 (space + time), scalar output u(x, t).
model = KAN(width=[xdim+1, 100, 40, 1], grid=5, k=3, grid_eps=1.0, noise_scale_base=0.25)
# NOTE(review): the 1e-32 tolerances presumably disable LBFGS early termination,
# so each step runs its full iteration budget — confirm against kan.LBFGS.
optimizer = LBFGS(model.parameters(), lr=1, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)

# print(model)

In [61]:
def train(steps):
    """Optimise ``model`` with LBFGS for ``steps`` outer iterations.

    Each epoch draws one fresh batch from get_data(). That batch is kept fixed
    for every closure evaluation LBFGS makes during the step, so the
    strong-Wolfe line search sees one deterministic objective. Previously the
    closure resampled the data on every call.

    Side effects: updates the module-level ``loss_int``, ``loss_bc`` (from the
    last closure evaluation) and ``x_int`` (latest batch), as before.

    Args:
        steps: number of optimizer steps (epochs).

    Returns:
        (loss_int_hist, loss_bc_hist): numpy arrays of length ``steps`` holding
        the interior and boundary losses recorded after each step.
    """
    global loss_int, loss_bc, x_int

    loss_int_hist = np.zeros(steps)
    loss_bc_hist = np.zeros(steps)

    pbar = tqdm(range(steps), desc='description')

    for epoch in pbar:
        # Sample once per step, outside the closure (see docstring).
        x_int, x_bc = get_data()

        def closure():
            global loss_int, loss_bc
            optimizer.zero_grad()
            loss_int, loss_bc = loss_fun(x_int, x_bc, xdim, model)
            loss = loss_int + lambda_b * loss_bc
            loss.backward()
            return loss

        optimizer.step(closure)

        if epoch % log == 0:
            pbar.set_description("interior pde loss: %.2e | bc loss: %.2e " % (loss_int.item(), loss_bc.item()))

        # .item() yields a plain float. Assigning a grad-tracking or CUDA tensor
        # straight into a numpy array can raise a conversion error.
        loss_int_hist[epoch] = loss_int.item()
        loss_bc_hist[epoch] = loss_bc.item()

    return loss_int_hist, loss_bc_hist

In [64]:
# Measure wall-clock training time. time.perf_counter() is monotonic and
# high-resolution, unlike time.time(), which can jump when the system clock
# is adjusted.
start_time = time.perf_counter()

train(steps)

end_time = time.perf_counter()
elapsed_time = end_time - start_time
print(f"Training completed in {elapsed_time:.2f} seconds.")

description:   0%|          | 0/40 [00:00<?, ?it/s]