In [10]:
#!/usr/bin/env python
# coding: utf-8

import os
import time
import numpy as np
import sys
sys.path.append("..")         # 添加上级目录到路径，用于导入自定义模块
from pinn import *            # 导入PINN基础模型
from grad_stats import *      # 导入梯度统计计算模块
import math
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.autograd import grad
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR, ExponentialLR, MultiStepLR
from scipy.interpolate import griddata
from pyDOE import lhs         # 拉丁超立方采样
import scipy.io               # 用于加载.mat数据文件

In [11]:
# Pick the compute device: the CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def sampler(num_r, num_b, num_0):
    """Draw the training points and build the validation grid for Allen-Cahn.

    The reference solution is read from '../data/AC.mat' (keys 'tt', 'x' and
    'uu'). Initial-condition and boundary points are picked at random from the
    stored grids; collocation points come from Latin hypercube sampling.

    Args:
        num_r: number of collocation points used for the PDE residual.
        num_b: number of time instants sampled for each of the two boundaries.
        num_0: number of initial-condition points.

    Returns:
        Tuple ``(X_train, X_lb, X_rb, X_ic, U_ic, X_mean, X_std, X, T,
        Exact_u, X_star, u_sol)``:
        X_train, X_lb, X_rb -- float32 tensors of (x, t) points for the
            collocation set, the left boundary (x = -1) and the right boundary
            (x = 1); all three are created with ``requires_grad=True``.
        X_ic, U_ic -- float32 tensors with the initial-condition points
            (x, 0) and the reference values taken from column 0 of the
            solution array.
        X_mean, X_std -- per-column mean / standard deviation over all
            training points, as float32 tensors with ``keepdims=True``.
        X, T -- NumPy meshgrid of the stored x and t coordinates.
        Exact_u -- real part of the stored solution array.
        X_star, u_sol -- flattened (x, t) grid and the matching flattened
            reference solution, both as column-shaped NumPy arrays.
    """

    def as_tensor(array, track_grad=False):
        # float32 tensor placed on the module-level `device`
        return torch.tensor(array, dtype=torch.float32, device=device,
                            requires_grad=track_grad)

    # Lower / upper bound of the sampling box (applied to both columns below).
    x_min, x_max = np.array([-1.0]), np.array([1.0])

    # Reference data: time grid, space grid and solution (real part only).
    mat = scipy.io.loadmat('../data/AC.mat')
    t_grid = mat['tt'].flatten()[:, None]
    x_grid = mat['x'].flatten()[:, None]
    Exact_u = np.real(mat['uu'])

    # Initial condition: random spatial nodes, values read from column 0.
    ic_rows = np.random.choice(x_grid.shape[0], num_0, replace=False)
    x_init = x_grid[ic_rows, :]
    u_init = Exact_u[ic_rows, 0:1]

    # Boundaries: random time instants shared by the left and right edge.
    bc_rows = np.random.choice(t_grid.shape[0], num_b, replace=False)
    t_bnd = t_grid[bc_rows, :]

    # Collocation points via Latin hypercube sampling in [-1, 1]^2; the second
    # column (time) is folded onto [0, 1] with an absolute value.
    colloc = x_min + (x_max - x_min) * lhs(2, num_r)
    colloc[:, 1:2] = np.abs(colloc[:, 1:2])

    # Assemble (x, t) pairs for each constraint set.
    ic_pts = np.concatenate((x_init, np.abs(0 * x_init)), 1)        # (x0, 0)
    left_pts = np.concatenate((0 * t_bnd + x_min[0], t_bnd), 1)     # (-1, tb)
    right_pts = np.concatenate((0 * t_bnd + x_max[0], t_bnd), 1)    # (+1, tb)

    # Validation grid and the reference solution flattened in the same order.
    X, T = np.meshgrid(x_grid, t_grid)
    X_star = np.hstack((X.flatten()[:, None], T.flatten()[:, None]))
    u_sol = Exact_u.T.flatten()[:, None]

    # Input statistics over every training point (for input normalisation).
    pooled = np.concatenate([colloc, left_pts, right_pts, ic_pts], 0)
    X_mean = as_tensor(np.mean(pooled, axis=0, keepdims=True))
    X_std = as_tensor(np.std(pooled, axis=0, keepdims=True))

    # Torch versions of the training sets; points that get differentiated
    # with respect to their coordinates track gradients.
    X_train = as_tensor(colloc, track_grad=True)
    X_lb = as_tensor(left_pts, track_grad=True)
    X_rb = as_tensor(right_pts, track_grad=True)
    X_ic = as_tensor(ic_pts)
    U_ic = as_tensor(u_init)

    return X_train, X_lb, X_rb, X_ic, U_ic, X_mean, X_std, X, T, Exact_u, X_star, u_sol



###### computes pde residual
def AC_res(uhat, data, eps=0.0001, coef=5.0):
    """Residual of the Allen-Cahn equation at the given points.

    Evaluates ``u_t - eps * u_xx + coef * u**3 - coef * u`` using automatic
    differentiation of ``uhat`` with respect to ``data``.

    Args:
        uhat: network prediction; must have been computed from ``data`` so
            that an autograd graph links the two.
        data: input tensor with ``requires_grad=True``; column 0 is the
            spatial coordinate x and column 1 is the time coordinate t.
        eps: diffusion coefficient multiplying ``u_xx`` (default 0.0001).
        coef: coefficient of the reaction term ``u**3 - u`` (default 5).

    Returns:
        Tensor with one column holding the pointwise residual. The graph is
        kept (``create_graph=True``) so the residual can be back-propagated.
    """
    # First derivatives of u with respect to both input columns.
    du = grad(outputs=uhat, inputs=data,
              grad_outputs=torch.ones_like(uhat), create_graph=True)[0]

    dudx = du[:, 0:1]     # du/dx
    dudt = du[:, 1:2]     # du/dt

    # Second derivative in x: differentiate du/dx again and keep column 0.
    dudxx = grad(outputs=dudx, inputs=data,
                 grad_outputs=torch.ones_like(dudx), create_graph=True)[0][:, 0:1]

    return dudt - eps * dudxx + coef * uhat**3 - coef * uhat


def AC_res_u_x(uhat, data): #data=(x,t)
    """Return du/dx, the derivative of ``uhat`` along column 0 of ``data``.

    Used for the derivative part of the periodic boundary condition. The
    autograd graph is kept (``create_graph=True``) so the result can itself
    be part of a loss.
    """
    (jacobian,) = grad(uhat, data,
                       grad_outputs=torch.ones_like(uhat),
                       create_graph=True)
    # Column 0 of the input is x, so column 0 of the gradient is du/dx.
    return jacobian[:, :1]

In [13]:
# Training configuration
i_print = 100          # log / evaluate every `i_print` epochs
all_losses=[]          # total loss at every log step, accumulated across all methods
list_of_l2_Errors=[]   # relative L2 error at every log step, accumulated across all methods

lr = 1e-3           # learning rate
mm         = 100    # loss weights are refreshed every `mm` epochs
alpha_ann  = 0.1    # moving-average coefficient used by the GW-PINN weight updates
n_epochs   = 300    # number of training epochs

layer_sizes =  [2, 128, 128, 128, 128, 1]    # network layout [input, hidden..., output]
num_0 = 512        # number of initial-condition points
num_b = 100        # boundary points per side (2*num_b boundary points in total)
num_r = 20000      # number of collocation points

# Learning-rate policy: True keeps the learning rate fixed, False enables the
# step scheduler. The policy name is the only part of the result path that
# differs between the two cases.
guding_lr = True
lr_tag = 'guding_lr' if guding_lr else 'step_lr'
path_loc = './results/%s_%s_rb0_%s_%s_%s_iter_%s_%s' % (lr_tag, lr, num_r, num_b, num_0, n_epochs, layer_sizes)

print('guding_lr, lr: ', guding_lr, lr)
print('num_r, num_b, num_0: ', num_r, num_b, num_0)
print('layer_sizes: ', layer_sizes)

# Create the result directory; exist_ok avoids the check-then-create race and
# makes the cell safe to re-run.
os.makedirs(path_loc, exist_ok=True)


# Weighting strategies to compare:
#   0: vanilla PINN (equal weighting)
#   1: GW-PINN, mean (max/avg);  2: GW-PINN, std;  3: GW-PINN, kurtosis
#   'DB_PINN_mean' / 'DB_PINN_std' / 'DB_PINN_kurt': the three DB-PINN variants
method_list = [0, 1, 2, 3, 'DB_PINN_mean', 'DB_PINN_std', 'DB_PINN_kurt']

guding_lr, lr:  True 0.001
num_r, num_b, num_0:  20000 100 512
layer_sizes:  [2, 128, 128, 128, 128, 1]


In [None]:
print(time.time())
# Main experiment loop: for every weighting strategy in `method_list`, train a
# fresh PINN on the Allen-Cahn problem and save prediction, loss, weight and
# error plots under `path_loc`.
for i, method in enumerate(method_list):   # follows method_list instead of a hard-coded count
    for j in range(1):   # one run per method (increase to average over several runs)

        print('i, j, method: ', i, j, method)
        save_loc = path_loc + '/method_' + str(method) + '/run_' + str(j)
        os.makedirs(save_loc, exist_ok=True)

        extras=str(num_r)+ "+"+ str(num_b) + "+" + str(num_0)
        print("#######Training with#####\n",extras)

        # Sample the training points and the validation grid.
        X_train, X_lb, X_rb, X_ic, U_ic, X_mean, X_std, X, T, Exact_u, X_star, u_sol = sampler(num_r, num_b, num_0)

        # Validation inputs / reference, prepared once per run. Evaluation
        # needs no input gradients, so the tensor does not track them.
        inp = torch.tensor(X_star, dtype=torch.float32, device=device)
        u_sol_flat = u_sol.reshape(-1)
        u_sol_norm = np.linalg.norm(u_sol_flat)

        # Build the network.
        net = PINN(sizes=layer_sizes, mean=X_mean, std=X_std, activation=torch.nn.Tanh()).to(device)

        # Loss weights (the residual weight stays at 1 for every method).
        lambd_r       = torch.ones(1, device=device)
        lambd_bc       = torch.ones(1, device=device)
        lambd_ic       = torch.ones(1, device=device)

        # Weight history, one entry per log step.
        lambd_r_all       = []
        lambd_bc_all      = []
        lambd_ic_all      = []

        # Loss and error history, one entry per log step.
        losses = []
        losses_initial  = []
        losses_boundary  = []
        losses_residual = []
        l2_error = []

        N_l = 0          # number of weight updates so far (used by the running averages)

        # Optimizer; the scheduler is only created for the step-decay policy.
        params = [{'params': net.parameters(), 'lr': lr}]
        milestones = [[20000,40000,60000]]       # epochs at which the learning rate decays
        optimizer = Adam(params)
        if not guding_lr:
            scheduler = MultiStepLR(optimizer, milestones[0], gamma=0.95)

        print("training with shape of residual points: ", X_train.size())
        print("training with shape of boundary points (*2): ", X_lb.size())
        print("training with shape of initial points: ", X_ic.size())

        start_time = time.time()

        # Training loop
        for epoch in range(n_epochs):

            # PDE residual loss on the collocation points.
            uhat  = net(X_train)
            res   = AC_res(uhat, X_train)
            l_reg = torch.mean((res)**2)

            # Periodic boundary loss: values and x-derivatives must match on
            # the left and right boundary.
            predl = net(X_lb)
            predr = net(X_rb)

            predl_dx = AC_res_u_x(predl, X_lb)
            predr_dx = AC_res_u_x(predr, X_rb)

            l_bc  = torch.mean((predl - predr)**2)
            l_bc += torch.mean((predl_dx - predr_dx)**2)

            # Initial-condition loss.
            pred_ic = net(X_ic)
            l_ic = torch.mean((pred_ic - U_ic)**2)

            # All loss terms stacked as [residual, boundary, initial]; used by
            # the DB-PINN difficulty index.
            L_t = torch.stack((l_reg, l_bc, l_ic))

            # Refresh the loss weights every `mm` epochs.
            # NOTE(review): the gradient-statistics helpers are called inside
            # no_grad; presumably they differentiate the already-built graph
            # with torch.autograd.grad -- confirm against grad_stats.
            with torch.no_grad():
                if epoch % mm == 0:
                    N_l += 1

                    # Gradient statistics of each loss term.
                    stdr,kurtr=loss_grad_stats(l_reg, net)
                    stdb,kurtb=loss_grad_stats(l_bc, net)
                    stdi,kurti=loss_grad_stats(l_ic, net)

                    maxr,meanr=loss_grad_max_mean(l_reg, net)
                    maxb,meanb=loss_grad_max_mean(l_bc, net,lambg=lambd_bc)
                    maxi,meani=loss_grad_max_mean(l_ic, net,lambg=lambd_ic)

                    # Running averages used by the DB-PINN variants.
                    if epoch == 0:
                        lam_avg_bc = torch.zeros(1, device=device)         # running mean of the boundary weight
                        lam_avg_ic = torch.zeros(1, device=device)         # running mean of the initial weight
                        running_mean_L = torch.zeros(1, device=device)     # running mean of the loss vector

                    if method == 1:
                        # GW-PINN (mean): max residual gradient / mean constraint
                        # gradient, blended in with coefficient alpha_ann.
                        lamb_hat = maxr/meanb
                        lambd_bc     = (1-alpha_ann)*lambd_bc + alpha_ann*lamb_hat
                        lamb_hat = maxr/meani
                        lambd_ic     = (1-alpha_ann)*lambd_ic + alpha_ann*lamb_hat

                    elif method == 2:
                        # GW-PINN (std), inverse Dirichlet: ratio of gradient
                        # standard deviations.
                        lamb_hat = stdr/stdb
                        lambd_bc     = (1-alpha_ann)*lambd_bc + alpha_ann*lamb_hat
                        lamb_hat = stdr/stdi
                        lambd_ic     = (1-alpha_ann)*lambd_ic + alpha_ann*lamb_hat

                    elif method == 3:
                        # GW-PINN (kurtosis): ratio of std/kurtosis statistics.
                        covr= stdr/kurtr
                        covb= stdb/kurtb
                        covi= stdi/kurti
                        lamb_hat = covr/covb
                        lambd_bc     = (1-alpha_ann)*lambd_bc + alpha_ann*lamb_hat
                        lamb_hat = covr/covi
                        lambd_ic     = (1-alpha_ann)*lambd_ic + alpha_ann*lamb_hat

                    elif method in ('DB_PINN_mean', 'DB_PINN_std', 'DB_PINN_kurt'):
                        # DB-PINN: the three variants differ only in the
                        # gradient statistic behind the aggregated weight.
                        if method == 'DB_PINN_mean':
                            hat_all = maxr/meanb + maxr/meani
                        elif method == 'DB_PINN_std':
                            hat_all = stdr/stdb + stdr/stdi
                        else:
                            covr= stdr/kurtr
                            covb= stdb/kurtb
                            covi= stdi/kurti
                            hat_all = covr/covb + covr/covi

                        # Incremental arithmetic mean of the loss vector over
                        # all weight updates so far.
                        mean_param = (1. - 1 / N_l)
                        running_mean_L = mean_param * running_mean_L + (1 - mean_param) * L_t.detach()

                        # Difficulty index: current loss / running mean loss.
                        l_t_vector = L_t/running_mean_L

                        # Split the aggregated weight between boundary and
                        # initial terms in proportion to their difficulty.
                        hat_bc = hat_all* l_t_vector[1]/(l_t_vector[1] + l_t_vector[2])
                        hat_ic = hat_all* l_t_vector[2]/(l_t_vector[1] + l_t_vector[2])

                        # The applied weights are running means of the targets.
                        lambd_bc = lam_avg_bc + 1/N_l*(hat_bc - lam_avg_bc)
                        lambd_ic = lam_avg_ic + 1/N_l*(hat_ic - lam_avg_ic)
                        lam_avg_bc = lambd_bc
                        lam_avg_ic = lambd_ic

                    else:
                        # Method 0 (baseline): equal weighting.
                        lambd_bc = torch.ones(1, device=device)
                        lambd_ic = torch.ones(1, device=device)

            # Weighted total loss; the weights enter as plain floats.
            loss = l_reg + lambd_bc.item()*l_bc + lambd_ic.item()*l_ic

            # Periodic logging.
            if epoch%i_print==0:

                # Relative L2 error on the validation grid, normalised by the
                # norm of the reference solution (same definition as the final
                # report after training).
                with torch.no_grad():
                    out = net(inp).cpu().numpy().reshape(u_sol.shape)
                tmp = np.linalg.norm(out.reshape(-1)-u_sol_flat)/u_sol_norm

                l2_error.append(tmp)
                list_of_l2_Errors.append(tmp)
                all_losses.append(loss.item())

                losses_initial.append(l_ic.item())
                losses_boundary.append(l_bc.item())
                losses_residual.append(l_reg.item())

                lambd_r_all.append(lambd_r.item())
                lambd_bc_all.append(lambd_bc.item())
                lambd_ic_all.append(lambd_ic.item())

                print("method={}, epoch {}/{}, loss={:.4f}, loss_r={:.6f}, loss_bc={:.6f}, loss_ic={:.6f}, lam_r={:.4f}, lam_bc={:.4f}, lam_ic={:.4f}, lr={:.5f}, l2_error(%)={:.3f}".format(method, epoch+1, n_epochs, loss.item(), l_reg.item(), l_bc.item(), l_ic.item(), lambd_r.item(), lambd_bc.item(), lambd_ic.item(), optimizer.param_groups[0]['lr'], tmp*100))

            # Back-propagation and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if not guding_lr:
                scheduler.step()      # learning-rate schedule

        # Training finished: final evaluation on the validation grid.
        elapsed_time = time.time() - start_time
        with torch.no_grad():
            out = net(inp)
        out = out.cpu().numpy().reshape(u_sol.shape)

        print("\n.....\n")
        print("Method: , j: ",method, j)
        print("pred rel. l2-error = {:e}\n".format(np.linalg.norm(out.reshape(-1)-u_sol.reshape(-1))/np.linalg.norm(u_sol.reshape(-1))))
        print("pred abs. error = {:e}\n".format(np.mean(np.abs(out.reshape(-1)-u_sol.reshape(-1)))))
        print("\n.....\n")

        # X_star is the row-major flattening of the (X, T) mesh, so the
        # prediction maps back onto the grid with a plain reshape; no
        # scattered-data interpolation is needed.
        U_pred = out.reshape(X.shape)

        # Figure 1: exact solution, prediction and absolute error (saved only).
        fig = plt.figure(1, figsize=(18, 5))
        fig_1 = plt.subplot(1, 3, 1)
        plt.pcolor(X, T, Exact_u.T, cmap='jet')
        plt.colorbar()
        plt.xlabel(r'$x$')
        plt.ylabel(r'$t$')
        plt.title('Exact $u(x)$')
        fig_2 = plt.subplot(1, 3, 2)
        plt.pcolor(X, T, U_pred, cmap='jet')
        plt.colorbar()
        plt.xlabel(r'$x$')
        plt.ylabel(r'$t$')
        plt.title('Predicted $u(x)$')
        fig_3 = plt.subplot(1, 3, 3)
        plt.pcolor(X, T, np.abs(Exact_u.T - U_pred), cmap='jet')
        plt.colorbar()
        plt.xlabel(r'$x$')
        plt.ylabel(r'$t$')
        plt.title('Absolute error')
        plt.tight_layout()
        plt.savefig(os.path.join(save_loc,'1.predictions.png'))
        plt.close()


        # Figure 2: loss terms per log step (raw strings keep the LaTeX
        # backslashes without invalid-escape warnings).
        fig_2 = plt.figure(2)
        ax = fig_2.add_subplot(1, 1, 1)
        ax.plot(losses_residual, label=r'$\mathcal{L}_{r}$')
        ax.plot(losses_boundary, label=r'$\mathcal{L}_{bc}$')
        ax.plot(losses_initial, label=r'$\mathcal{L}_{ic}$')
        ax.set_yscale('log')
        ax.set_xlabel('iterations')
        ax.set_ylabel('Loss')
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(save_loc, '2.loss.png'))
        plt.show()
        plt.close()

        # Figure 3: boundary and initial loss weights per log step.
        fig_3 = plt.figure(3)
        ax = fig_3.add_subplot(1, 1, 1)
        ax.plot(lambd_bc_all, label=r'$\lambda_{bc}$')
        ax.plot(lambd_ic_all, label=r'$\lambda_{ic}$')
        ax.set_xlabel('iterations')
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(save_loc,'3.learned_weights.png'))
        plt.show()
        plt.close()


        # Figure 4: relative L2 error per log step.
        fig_4 = plt.figure(4)
        ax = fig_4.add_subplot(1, 1, 1)
        ax.plot(l2_error)
        ax.set_xlabel('iterations')
        plt.tight_layout()
        plt.savefig(os.path.join(save_loc,'4.L2_error.png'))
        plt.show()
        plt.close()
        