In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import numpy as np
from numpy import random as npr
from math import gamma
from math import factorial
from sobol_seq import sobol_seq
import matplotlib.pyplot as plt

from scipy.integrate import solve_ivp
import scipy.integrate as integrate

In [2]:
device = torch.device('cuda:0')
device

device(type='cuda', index=0)

In [3]:
n_collocation = 10000
n_validation = 1000
n_initial = 2000
n_boundary = 2000
x_lower = -1
x_upper = 1
y_lower = -1
y_upper = 1

a1 = 1
a2 = 8


def analytical(x,t):
    p = torch.pi
    s1 = torch.sin(a1*p*x)
    s2 = torch.sin(a2*p*t)

    return s1*s2


def right_side(x,t):
    p = torch.pi
    s1 = torch.sin(a1*p*x)
    s2 = torch.sin(a2*p*t)

    return (1 - (a1**2 + a2**2)*p**2)*s1*s2

# Collocation points
x_collocation = (torch.rand(n_collocation) * (x_upper - x_lower) + x_lower).to(device)
y_collocation = (torch.rand(n_collocation) * (y_upper - y_lower) + y_lower).to(device)

#validation points
x_validation = (torch.rand(n_validation) * (x_upper - x_lower) + x_lower).to(device)
y_validation = (torch.rand(n_validation) * (y_upper - y_lower) + y_lower).to(device)

# Initial and boundary condition points
x_bc = (torch.rand(n_initial) * (x_upper - x_lower) + x_lower).to(device)
y_bc_left = y_lower*torch.ones(n_boundary).to(device)
y_bc_right = y_upper*torch.ones(n_boundary).to(device)
u_ybc_left = analytical(x_bc, y_bc_left)
u_ybc_right = analytical(x_bc, y_bc_right)

y_bc = (torch.rand(n_boundary) * (y_upper - y_lower) + y_lower).to(device)
x_bc_left = x_lower*torch.ones(n_boundary).to(device)
x_bc_right = x_upper*torch.ones(n_boundary).to(device)
u_xbc_left = analytical(x_bc_left, y_bc)
u_xbc_right = analytical(x_bc_right, y_bc)

exact = analytical(x_validation, y_validation)
rhs = right_side(x_collocation, y_collocation)

In [4]:
# gaussian wavelet and its derivatives
def wavelet(x,y,jx,jy,kx,ky):
    return (jx*x - kx)*(jy*y - ky)*torch.exp(-((jx*x - kx)**2 + (jy*y - ky)**2)/2)

def D1xwavelet(x,y,jx,jy,kx,ky):
    return jx*(1-(jx*x - kx)**2)*(jy*y - ky)*torch.exp(-((jx*x - kx)**2 + (jy*y - ky)**2)/2)

def D1twavelet(x,y,jx,jy,kx,ky):
    return jy*(1-(jy*y - ky)**2)*(jx*x - kx)*torch.exp(-((jx*x - kx)**2 + (jy*y - ky)**2)/2)

def D2xwavelet(x,y,jx,jy,kx,ky):
    return -(jx**2)*(jx*x - kx)*(jy*y-ky)*(3 - (jx*x - kx)**2)*torch.exp(-((jx*x - kx)**2 + (jy*y - ky)**2)/2)

def D2twavelet(x,y,jx,jy,kx,ky):
    return -(jy**2)*(jx*x - kx)*(jy*y-ky)*(3 - (jy*y - ky)**2)*torch.exp(-((jx*x - kx)**2 + (jy*y - ky)**2)/2)



# resolution
Jx = torch.tensor([-4, -3, -2, -1, 0.0, 1, 2, 3, 4, 5])
Jy = torch.tensor([-4, -3, -2, -1, 0.0, 1, 2, 3, 4, 5])

family = torch.tensor([(2**jx,2**jy,kx,ky) for jx in Jx for jy in Jy for kx in range(-int(1.2*2**jx),int(1.2*2**jx)) for ky in range(-(int(1.2*2**jy)),int(1.2*2**jy))]).to(device)

print(len(family))

# wavelet matrices
Wfamily = torch.stack([wavelet(x_collocation,y_collocation,family[i,0],family[i,1],family[i,2],family[i,3]) for i in range(len(family))]).T
DW2x = torch.stack([D2xwavelet(x_collocation,y_collocation,family[i,0],family[i,1],family[i,2],family[i,3]) for i in range(len(family))]).T
DW2y = torch.stack([D2twavelet(x_collocation,y_collocation,family[i,0],family[i,1],family[i,2],family[i,3]) for i in range(len(family))]).T

xWbc_left = torch.stack([wavelet(x_bc_left,y_bc,family[i,0],family[i,1],family[i,2],family[i,3]) for i in range(len(family))]).T
xWbc_right = torch.stack([wavelet(x_bc_right,y_bc,family[i,0],family[i,1],family[i,2],family[i,3]) for i in range(len(family))]).T
yWbc_left = torch.stack([wavelet(x_bc,y_bc_left,family[i,0],family[i,1],family[i,2],family[i,3]) for i in range(len(family))]).T
yWbc_right = torch.stack([wavelet(x_bc,y_bc_right,family[i,0],family[i,1],family[i,2],family[i,3]) for i in range(len(family))]).T

WVal = torch.stack([wavelet(x_validation,y_validation,family[i,0],family[i,1],family[i,2],family[i,3]) for i in range(len(family))]).T


21316


In [5]:
class WPINN_Network(nn.Module):
    def __init__(self, input_size = n_collocation, 
                 num_hidden_layers1 = 2, 
                 num_hidden_layers2 = 5, 
                 hidden_neurons = 50, 
                 family_size = len(family)):
        
        super(WPINN_Network, self).__init__()
        
        self.activation = nn.Tanh()
        
        # processes each (x,y) point to create single feature
        first_stage_layers = []
        
        # Input layer
        first_stage_layers.append(nn.Linear(2, hidden_neurons)) 
        first_stage_layers.append(self.activation)
        
        for _ in range(num_hidden_layers1):
            first_stage_layers.append(nn.Linear(hidden_neurons, hidden_neurons))
            first_stage_layers.append(self.activation)
        
        # single feature per point
        first_stage_layers.append(nn.Linear(hidden_neurons, 1))
        self.first_stage = nn.Sequential(*first_stage_layers)
        
        #processes all point features to create global coefficients
        second_stage_layers = []
        
        second_stage_layers.append(nn.Linear(input_size, hidden_neurons))
        second_stage_layers.append(self.activation)
        
        for _ in range(num_hidden_layers2):  
            second_stage_layers.append(nn.Linear(hidden_neurons, hidden_neurons))
            second_stage_layers.append(self.activation)
        
        # Final layer outputs the wavelet coefficients
        second_stage_layers.append(nn.Linear(hidden_neurons, family_size))
        self.second_stage = nn.Sequential(*second_stage_layers)
        
        # Initialize weights
        for network in [self.first_stage, self.second_stage]:
            for m in network:
                if isinstance(m, nn.Linear):
                    init.xavier_uniform_(m.weight)
                    init.constant_(m.bias, 0)
        
   
        self.output_layers = nn.ModuleList()
        for i in range(3):
            output_layer = nn.Linear(family_size, 1)
            output_layer.weight.requires_grad = False
            output_layer.bias.data = torch.tensor(0.0 if i > 0 else 0.5)
            output_layer.bias.requires_grad = i == 0
            self.output_layers.append(output_layer)

    def forward(self, x, y, W):
      
        inputs = torch.stack([x, y], dim=-1)  
        
       
        point_features = self.first_stage(inputs) 
        point_features = point_features.squeeze(-1) 
        
        coefficients = self.second_stage(point_features)  
        
        # Generate outputs using the wavelet family
        outputs = []
        for i, layer in enumerate(self.output_layers):
            layer.weight.data = W[i]
            outputs.append(layer(coefficients))

        bias = self.output_layers[0].bias
        
        return coefficients, bias, outputs

model = WPINN_Network().to(device)
optimizer1 = optim.Adam(model.parameters(), lr=0.0001)
c, b, u = model(x_collocation, y_collocation, [Wfamily, DW2x, DW2y])
u[0].shape

torch.Size([10000])

In [6]:
# coefficient refinement network

class CoefficientRefinementNetwork(nn.Module):
    def __init__(self, initial_coefficients, initial_bias, family_size = len(family)):
        
        super(CoefficientRefinementNetwork, self).__init__()
        
        # Store initial coefficients and bias from WPINN network
        self.coefficients = nn.Parameter(initial_coefficients.clone().detach())
        self.bias = nn.Parameter(initial_bias.clone().detach())
        
        self.output_layers = nn.ModuleList()
        for i in range(3):
            output_layer = nn.Linear(family_size, 1)
            output_layer.weight.requires_grad = False
            output_layer.bias.data = (torch.tensor(0.0) if i > 0 else self.bias)
            output_layer.bias.requires_grad = i == 0
            self.output_layers.append(output_layer)

    def forward(self, x, y, W):
        outputs = []
        for i, layer in enumerate(self.output_layers):
            layer.weight.data = W[i]
            outputs.append(layer(self.coefficients))

        bias = self.output_layers[0].bias
        
        return self.coefficients, bias, outputs

In [7]:
def wpinn_loss():   
    # PDE loss at collocation points
    x_interior = x_collocation.clone()
    y_interior = y_collocation.clone()

    global c, b
    c, b, u = model(x_interior, y_interior, [Wfamily, DW2x, DW2y])
    u_pred_xbc_left = torch.mv(xWbc_left, c) + b
    u_pred_xbc_right = torch.mv(xWbc_right, c) + b
    u_pred_ybc_left = torch.mv(yWbc_left, c) + b
    u_pred_ybc_right = torch.mv(yWbc_right, c) + b
    
    pde_loss = torch.mean((u[2] + u[1] + u[0] - rhs) ** 2)
    
    xbc_loss = torch.mean((u_pred_xbc_left - u_xbc_left)** 2) + torch.mean((u_pred_xbc_right - u_xbc_right)** 2)

    ybc_loss = torch.mean((u_pred_ybc_left - u_ybc_left) ** 2) + torch.mean((u_pred_ybc_right - u_ybc_right) ** 2)
    
    total_loss = pde_loss + xbc_loss + ybc_loss
    
    return total_loss, pde_loss, xbc_loss, ybc_loss

def train_pinn(optimizer, num_prints):
    # Training loop
    pde_losses = []
    xbc_losses = []
    ybc_losses = []
    for epoch in range(num_epochs):
        optimizer.zero_grad()

        total_loss, pde_loss, xbc_loss, ybc_loss = wpinn_loss()
        
        total_loss.backward()
        optimizer.step()
        
        pde_losses.append(pde_loss.item())
        xbc_losses.append(xbc_loss.item())
        ybc_losses.append(ybc_loss.item())

        # Validation
        if epoch % ((num_epochs-1)/num_prints) == 0:
            numerical = torch.mv(WVal, c) + b
            errL2 = (torch.sum(torch.abs(exact-numerical)**2))**0.5 / (torch.sum(torch.abs(exact)**2))**0.5
            errMax = torch.max(torch.abs(exact-numerical))
            
            print(f'Epoch [{epoch}/{num_epochs-1}], '
                  f'Total Loss: {total_loss.item():.6f}, '
                  f'PDE Loss: {pde_loss.item():.6f}, '
                  f'xBC Loss: {xbc_loss.item():.6f}, '
                  f'yBC Loss: {ybc_loss.item():.6f}\n\t\t'
                  f'RelativeL2: {errL2},\t\t'
                  f'Max: {errMax}\n' )
    
    return [pde_losses, xbc_losses, ybc_losses]

In [8]:
#call for training

num_epochs = 10**5+1
loss = train_pinn(optimizer1, num_prints=20)

Epoch [0/100000], Total Loss: 107989.210938, PDE Loss: 107988.195312, xBC Loss: 0.506537, yBC Loss: 0.507070
		RelativeL2: 1.440400242805481,		Max: 1.5707181692123413

Epoch [5000/100000], Total Loss: 1.027154, PDE Loss: 1.025680, xBC Loss: 0.001126, yBC Loss: 0.000349
		RelativeL2: 0.008981514722108841,		Max: 0.02645241841673851

Epoch [10000/100000], Total Loss: 0.930214, PDE Loss: 0.930083, xBC Loss: 0.000082, yBC Loss: 0.000049
		RelativeL2: 0.002351896371692419,		Max: 0.007755249738693237

Epoch [15000/100000], Total Loss: 0.051765, PDE Loss: 0.051718, xBC Loss: 0.000029, yBC Loss: 0.000017
		RelativeL2: 0.0011754422448575497,		Max: 0.004682749509811401

Epoch [20000/100000], Total Loss: 0.000313, PDE Loss: 0.000281, xBC Loss: 0.000021, yBC Loss: 0.000011
		RelativeL2: 0.0009113270789384842,		Max: 0.003990165889263153

Epoch [25000/100000], Total Loss: 0.743103, PDE Loss: 0.743083, xBC Loss: 0.000012, yBC Loss: 0.000008
		RelativeL2: 0.0013243354624137282,		Max: 0.0034887716174125

KeyboardInterrupt: 

In [9]:
#coefficient refinement

model = CoefficientRefinementNetwork(initial_coefficients=c, initial_bias = b).to(device)
optimizer2 = optim.Adam(model.parameters(), lr=0.0001)  # Lower learning rate

In [10]:
num_epochs = 10**5+1
l = train_pinn(optimizer2, num_prints=20)

Epoch [0/100000], Total Loss: 0.000255, PDE Loss: 0.000234, xBC Loss: 0.000013, yBC Loss: 0.000007
		RelativeL2: 0.0020190866198390722,		Max: 0.003667157143354416

Epoch [5000/100000], Total Loss: 0.005296, PDE Loss: 0.005289, xBC Loss: 0.000004, yBC Loss: 0.000004
		RelativeL2: 0.00039791909512132406,		Max: 0.0023054543416947126



KeyboardInterrupt: 

In [13]:
# Testing

ntest = 100
xtest = torch.linspace(-1, 1, ntest).to(device)
ttest = torch.linspace(0, 1, ntest).to(device)
    
x_grid, t_grid = torch.meshgrid(xtest, ttest)
x_test = x_grid.reshape(-1)
t_test = t_grid.reshape(-1)

WTest = torch.stack([wavelet(x_test,t_test,family[i,0],family[i,1],family[i,2],family[i,3]) for i in range(len(family))]).T

Uexact = analytical(x_test, t_test).reshape(ntest, ntest)

with torch.no_grad():
    u_pred = torch.mv(WTest, c) + b
    Upred = u_pred.reshape(ntest, ntest)

errL2 = (torch.sum(torch.abs(Uexact-Upred)**2))**0.5 / (torch.sum(torch.abs(Uexact)**2))**0.5
errMax = torch.max(torch.abs(Uexact-Upred))
            
print(f'RelativeL2: {errL2} \t Max: {errMax}' )

RelativeL2: 0.0005659862654283643 	 Max: 0.004936600103974342
