In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import random
from torch.optim import LBFGS
from tqdm import tqdm
import scipy.io

from DCGD_BFGS import DualCenter_BFGS

from util import *
from model.pinn import PINNs
from model.pinnsformer import PINNsformer

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Seed every RNG the notebook draws from (Python, NumPy, PyTorch CPU and CUDA)
# so repeated runs start from the same state. GPU kernels may still be
# nondeterministic.
seed = 0
for _seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)

In [None]:
# Train PINNsformer: collocation and boundary points from get_data over
# [0, 2*pi] x [0, 1] (presumably the x and t ranges) on a 51 x 51 grid, plus a
# 101 x 101 grid kept in res_test for evaluation.
res, b_left, b_right, b_upper, b_lower = get_data([0, 2 * np.pi], [0, 1], 51, 51)
res_test, _, _, _, _ = get_data([0, 2 * np.pi], [0, 1], 101, 101)


def _to_sequence_tensor(points):
    """Apply make_time_sequence (num_step=5, step=1e-4) to `points` and return a
    float32 tensor on `device` with requires_grad=True (the residual loss
    differentiates the model output w.r.t. the x/t slices of these tensors)."""
    seq = make_time_sequence(points, num_step=5, step=1e-4)
    return torch.tensor(seq, dtype=torch.float32, requires_grad=True).to(device)


res, b_left, b_right, b_upper, b_lower = (
    _to_sequence_tensor(pts) for pts in (res, b_left, b_right, b_upper, b_lower)
)


def _split_xt(seq):
    """Split a 3-D tensor into its channel 0 (x) and channel 1 (t) slices along
    the last axis, keeping a trailing singleton dimension on each."""
    return seq[:, :, 0:1], seq[:, :, 1:2]


x_res, t_res = _split_xt(res)
x_left, t_left = _split_xt(b_left)
x_right, t_right = _split_xt(b_right)
x_upper, t_upper = _split_xt(b_upper)
x_lower, t_lower = _split_xt(b_lower)

def init_weights(m):
    """Re-initialise a module in place if it is an ``nn.Linear``.

    Intended for ``model.apply(init_weights)``. For every ``nn.Linear``
    (including subclasses), the weight is drawn from a Xavier/Glorot uniform
    distribution and the bias, if present, is set to 0.01. Any other module
    is left untouched.

    Args:
        m: the module visited by ``nn.Module.apply``.
    """
    if isinstance(m, nn.Linear):
        # `xavier_uniform` (no trailing underscore) is a deprecated alias that
        # warns on every call; `xavier_uniform_` samples the same distribution.
        torch.nn.init.xavier_uniform_(m.weight)
        # Layers built with bias=False have m.bias = None; `m.bias.data` would
        # raise AttributeError on them.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)

In [None]:
# Build PINNsformer on `device`, re-initialise its nn.Linear layers
# (Module.apply returns the module itself, so the calls chain), then attach
# an L-BFGS optimiser with strong-Wolfe line search.
model = PINNsformer(d_out=1, d_hidden=512, d_model=32, N=1, heads=2).to(device).apply(init_weights)
optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe')

print(model)
print(get_n_params(model))

In [None]:
# DCGD loss-weight optimiser wrapping the L-BFGS optimiser created above. The
# training loop calls `.step([loss_res, loss_bd])` on it and uses the first two
# returned entries as the weights (wr, wb). `num_pde=1` presumably means a single
# PDE-residual loss term; confirm against DCGD_BFGS.DualCenter_BFGS.
weight_optimizer = DualCenter_BFGS(optim, num_pde=1)

In [None]:
# Train with L-BFGS on the weighted loss  wr * L_res + wb * (L_bc + L_ic).
# Every `dcgd_every` iterations the DCGD weight optimiser receives the
# unweighted [residual, boundary] losses and returns new weights (wr, wb).
loss_track = []  # [loss_res, loss_bc, loss_ic] appended on every loss evaluation
wr, wb = 1, 1
n_iters = 500
dcgd_every = 10


def compute_losses():
    """Evaluate the unweighted loss terms at the current model parameters.

    Returns:
        (loss_res, loss_bc, loss_ic), scalar tensors still attached to the graph:
        - loss_res: mean squared residual of u_t + 50 * u_x on the collocation points;
        - loss_bc:  mean squared mismatch between predictions on the `upper` and
                    `lower` boundary points (presumably the periodic x-boundaries,
                    confirm against get_data);
        - loss_ic:  mean squared error of the `left` boundary predictions (first
                    entry along dim 1) against sin(x) (presumably the t = 0 slice).
    """
    pred_res = model(x_res, t_res)
    pred_left = model(x_left, t_left)
    pred_upper = model(x_upper, t_upper)
    pred_lower = model(x_lower, t_lower)
    # The `right` boundary enters no loss term, so no forward pass is spent on it.

    ones = torch.ones_like(pred_res)
    # create_graph=True so the residual loss can itself be backpropagated.
    u_x = torch.autograd.grad(pred_res, x_res, grad_outputs=ones, retain_graph=True, create_graph=True)[0]
    u_t = torch.autograd.grad(pred_res, t_res, grad_outputs=ones, retain_graph=True, create_graph=True)[0]

    loss_res = torch.mean((u_t + 50 * u_x) ** 2)
    loss_bc = torch.mean((pred_upper - pred_lower) ** 2)
    loss_ic = torch.mean((pred_left[:, 0] - torch.sin(x_left[:, 0])) ** 2)
    return loss_res, loss_bc, loss_ic


def _log_losses(loss_res, loss_bc, loss_ic):
    """Append the three loss values (as Python floats) to loss_track."""
    loss_track.append([loss_res.item(), loss_bc.item(), loss_ic.item()])


def closure():
    """L-BFGS closure: log the terms, backprop the weighted total, and return it."""
    loss_res, loss_bc, loss_ic = compute_losses()
    _log_losses(loss_res, loss_bc, loss_ic)

    loss = wr * loss_res + wb * (loss_bc + loss_ic)
    optim.zero_grad()
    loss.backward()
    return loss


def dcgd_closure():
    """Log the terms and return the unweighted [residual, boundary] losses for DCGD."""
    loss_res, loss_bc, loss_ic = compute_losses()
    _log_losses(loss_res, loss_bc, loss_ic)
    loss_bd = loss_bc + loss_ic
    return [loss_res, loss_bd]


for i in tqdm(range(n_iters)):
    if i % dcgd_every == 0:
        weights = weight_optimizer.step(dcgd_closure())
        wr, wb = weights[0], weights[1]

    optim.step(closure)


In [None]:
# Most recently logged (unweighted) loss terms and their plain sum. The DCGD
# weights used during training are not applied here.
last_res, last_bc, last_ic = loss_track[-1]
print('Loss Res: {:4f}, Loss_BC: {:4f}, Loss_IC: {:4f}'.format(last_res, last_bc, last_ic))
print('Train Loss: {:4f}'.format(np.sum(loss_track[-1])))

In [None]:
# Training curves: the PDE residual vs. the combined boundary term (bc + ic),
# which is the `loss_bd` that DCGD balances against the residual. The previous
# version plotted column 1 twice (bc + bc) and never included ic.
loss_hist = np.array(loss_track)  # columns: [loss_res, loss_bc, loss_ic]
plt.plot(loss_hist[:, 0], label='res')
plt.plot(loss_hist[:, 1] + loss_hist[:, 2], label='bc+ic')
plt.legend()
plt.yscale('log', base=10)
plt.show()

In [None]:
# Prepare the 101 x 101 evaluation grid in the same sequence/tensor format as
# the training inputs. The guard keeps this cell idempotent: re-running it would
# otherwise pass the already-converted tensor back into make_time_sequence.
if not torch.is_tensor(res_test):
    res_test = make_time_sequence(res_test, num_step=5, step=1e-4)
    res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)


In [None]:
# Evaluate PINNsformer on the test grid and compare it with the reference
# solution stored in convection.mat via relative L1 / L2 errors.
x_test, t_test = res_test[:, :, 0:1], res_test[:, :, 1:2]

with torch.no_grad():
    # keep only index 0 along dim 1 (presumably the first step of each sequence)
    pred = model(x_test, t_test)[:, 0:1].cpu().detach().numpy()

pred = pred.reshape(101, 101)

mat = scipy.io.loadmat('./pinnsformer/demo/convection/convection.mat')
u = mat['u'].reshape(101, 101)

diff = u - pred
rl1 = np.sum(np.abs(diff)) / np.sum(np.abs(u))
rl2 = np.sqrt(np.sum(diff ** 2) / np.sum(u ** 2))

print('relative L1 error: {:4f}'.format(rl1))
print('relative L2 error: {:4f}'.format(rl2))

# Predicted u(x, t). The extent labels x from 0 to 2*pi, with t = 0 on the top
# edge and t = 1 on the bottom.
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(pred, extent=[0, np.pi * 2, 1, 0], aspect='auto')

cbar = fig.colorbar(im)
cbar.ax.tick_params(labelsize=15)

ax.tick_params(axis='both', labelsize=20)
ax.set_xlabel(r'$x$', fontsize=30)
ax.set_ylabel(r'$t$', fontsize=30)

fig.tight_layout()
fig.savefig('./convection_pinnsformer_pred.pdf', format='pdf', bbox_inches='tight')
plt.show()

In [None]:
# Reference solution u(x, t) loaded from convection.mat, drawn with the same
# extent and styling as the prediction plot.
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(u, extent=[0, np.pi * 2, 1, 0], aspect='auto')
ax.set_xlabel(r'$x$', fontsize=30)
ax.set_ylabel(r'$t$', fontsize=30)
ax.tick_params(axis='both', labelsize=20)

cbar = fig.colorbar(im)
cbar.ax.tick_params(labelsize=15)

fig.tight_layout()
fig.savefig('./convection_exact.pdf', format='pdf', bbox_inches='tight')
plt.show()

In [None]:
# Pointwise absolute error |pred - u| over the same (x, t) extent.
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(np.abs(pred - u), extent=[0, np.pi * 2, 1, 0], aspect='auto')
ax.set_xlabel(r'$x$', fontsize=30)
ax.set_ylabel(r'$t$', fontsize=30)
ax.tick_params(axis='both', labelsize=20)

cbar = fig.colorbar(im)
cbar.ax.tick_params(labelsize=15)
fig.tight_layout()

fig.savefig('./convection_dcgd_error.pdf', format='pdf', bbox_inches='tight')
plt.show()