In [2]:
# Cell 1
import math
import numpy as np
import torch
import torch.nn as nn
import torch.fft
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from tqdm import tqdm
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (8,5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Device: cpu


In [4]:
# Cell 2: Burgers PDE solver (pseudo-spectral, periodic BC) + GRF initial condition generator
# We'll solve u_t + u u_x = nu u_xx on x in [0,1], with periodic BC, integrate to t=1 using RK4.

# Generate the initial value
def sample_grf(n_points, lengthscale=0.05, scale=1.0, seed=None):
    if seed is not None:
        np.random.seed(seed)
    k = np.fft.fftfreq(n_points, d=1.0/n_points) 
    power = np.exp(-(2*np.pi*k*lengthscale)**2)
    phases = np.random.normal(size=(n_points,)) + 1j * np.random.normal(size=(n_points,))
    ft = phases * np.sqrt(power)
    x = np.fft.ifft(ft).real
    x = x - x.mean()
    x = (x / x.std()) * scale
    return x

# Integrate Burgers with pseudospectral and RK4. Return u(x,1)
def burgers_pseudospectral(u0, nu, n_steps=200, dt=None):
    N = u0.size
    x = np.linspace(0,1,N,endpoint=False)
    k = 2*np.pi*np.fft.fftfreq(N, d=1.0/N)
    k2 = k**2
    if dt is None:
        dt = 1.0/n_steps
    u = u0.copy()
    def rhs(u):
        uhat = np.fft.fft(u)
        ux = np.fft.ifft(1j * k * uhat).real
        nonlinear = - u * ux
        visc = nu * np.fft.ifft(-k2 * uhat).real
        return nonlinear + visc

    t = 0.0
    for _ in range(n_steps):
        k1 = rhs(u)
        k2v = rhs(u + 0.5*dt*k1)
        k3 = rhs(u + 0.5*dt*k2v)
        k4 = rhs(u + dt*k3)
        u = u + (dt/6.0)*(k1 + 2*k2v + 2*k3 + k4)
        t += dt
    return u


In [5]:
# Cell 3 generate training dataset

def make_dataset(n_train=1000, n_test=200, n_points=1024, nu=1e-2, grf_lenscale=0.03):
    X_train = []
    Y_train = []
    X_test = []
    Y_test = []
    total = n_train + n_test
    print("Generating dataset (this can take a while)... total samples:", total)
    for i in tqdm(range(total)):
        u0 = sample_grf(n_points, lengthscale=grf_lenscale, scale=1.0)
        u1 = burgers_pseudospectral(u0, nu, n_steps=400)  
        if i < n_train:
            X_train.append(u0.astype(np.float32))
            Y_train.append(u1.astype(np.float32))
        else:
            X_test.append(u0.astype(np.float32))
            Y_test.append(u1.astype(np.float32))
    X_train = np.stack(X_train)
    Y_train = np.stack(Y_train)
    X_test = np.stack(X_test)
    Y_test = np.stack(Y_test)
    return X_train, Y_train, X_test, Y_test

X_train, Y_train, X_test, Y_test = make_dataset(
    n_train=1000, n_test=200, n_points=1024
)


Generating dataset (this can take a while)... total samples: 1200


  0%|          | 0/1200 [00:00<?, ?it/s]

  nonlinear = - u * ux
  return ufunc(a, fct, axes=[(axis,), (), (axis,)], out=out)
100%|██████████| 1200/1200 [03:18<00:00,  6.04it/s]


In [6]:
# Cell 4 FNO in 1D

class SpectralConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, modes):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes
        self.scale = 1 / (in_channels * out_channels)
        self.weight = nn.Parameter(self.scale * torch.randn(in_channels, out_channels, self.modes, 2))

    def compl_mul1d(self, input, weight):
        a, b = input[...,0], input[...,1]
        c, d = weight[...,0], weight[...,1]
        real = torch.einsum("bim, iom -> bom", a, c) - torch.einsum("bim, iom -> bom", b, d)
        imag = torch.einsum("bim, iom -> bom", a, d) + torch.einsum("bim, iom -> bom", b, c)
        return torch.stack((real, imag), dim=-1)

    def forward(self, x):
        # x: (batch, channels, n)
        batchsize, channels, n = x.shape
        x_ft = torch.fft.rfft(x, dim=-1) 
        out_ft = torch.zeros(batchsize, self.out_channels, x_ft.shape[-1], 2, device=x.device)
        modes = min(self.modes, x_ft.shape[-1])
        input_modes = torch.stack((x_ft[..., :modes].real, x_ft[..., :modes].imag), dim=-1) 
        w = self.weight 
        out_modes = self.compl_mul1d(input_modes, w.permute(0,1,2,3))
        out_ft[..., :modes, :] = out_modes
        real = out_ft[...,0]  # (b, out, nfreqs)
        imag = out_ft[...,1]
        complex_ft = torch.complex(real, imag)
        x_out = torch.fft.irfft(complex_ft, n=n, dim=-1)
        return x_out

class FNO1d(nn.Module):
    def __init__(self, modes=16, width=64, depth=4):
        super().__init__()
        self.modes = modes
        self.width = width
        self.depth = depth

        self.input_proj = nn.Linear(1, self.width)
        self.spectral_layers = nn.ModuleList()
        self.w_conv = nn.ModuleList()
        for _ in range(self.depth):
            self.spectral_layers.append(SpectralConv1d(self.width, self.width, modes))
            self.w_conv.append(nn.Conv1d(self.width, self.width, kernel_size=1))

        self.output_proj = nn.Sequential(
            nn.Linear(self.width, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # x: (batch, n_points)
        b, n = x.shape
        x = x.unsqueeze(-1)  # (b, n, 1)
        x = self.input_proj(x)  # (b, n, width)
        x = x.permute(0,2,1)  # (b, width, n)
        for spec, w in zip(self.spectral_layers, self.w_conv):
            x1 = spec(x)  # (b, width, n)
            x2 = w(x)  # (b, width, n)
            x = x1 + x2
            x = torch.relu(x)
        x = x.permute(0,2,1)  # (b, n, width)
        x = self.output_proj(x)  # (b, n, 1)
        return x.squeeze(-1)  # (b, n)


In [7]:
# Cell 5

class KFNGD:
    def __init__(self, model, lr=1e-3, damping=1e-2, ema_decay=0.05):
        self.model = model
        self.lr = lr
        self.damping = damping
        self.ema_decay = ema_decay
        self.layer_infos = []
        self._register_layers()

    def _register_layers(self):
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                info = {
                    'name': name, 'module': module,
                    'A': None, 'G': None
                }
                self.layer_infos.append(info)
            elif isinstance(module, nn.Conv1d) and module.kernel_size == (1,):
                info = {
                    'name': name, 'module': module,
                    'A': None, 'G': None
                }
                self.layer_infos.append(info)
        for info in self.layer_infos:
            mod = info['module']
            def forward_hook(module, input, output, info=info):
                info['a'] = input[0].detach().clone()
            mod.register_forward_hook(forward_hook)
            def backward_hook(module, grad_input, grad_output, info=info):
                info['g'] = grad_output[0].detach().clone()
            mod.register_full_backward_hook(backward_hook)

    def step(self, batch_size):
        for info in self.layer_infos:
            mod = info['module']
            a = info.get('a', None)
            g = info.get('g', None)
            if a is None or g is None:
                continue
            if isinstance(mod, nn.Linear):
                a_flat = a.contiguous().reshape(a.shape[0], -1)
                g_flat = g.contiguous().reshape(g.shape[0], -1)
            else:
                a_flat = a.permute(0,2,1).contiguous().reshape(-1, a.shape[1])
                g_flat = g.permute(0,2,1).contiguous().reshape(-1, g.shape[1])
            A_batch = (a_flat.t() @ a_flat) / (a_flat.shape[0])
            G_batch = (g_flat.t() @ g_flat) / (g_flat.shape[0])
            if info['A'] is None:
                info['A'] = A_batch
                info['G'] = G_batch
            else:
                info['A'] = (1.0 - self.ema_decay) * info['A'] + self.ema_decay * A_batch
                info['G'] = (1.0 - self.ema_decay) * info['G'] + self.ema_decay * G_batch

            damping = self.damping
            A_inv = torch.linalg.inv(info['A'] + damping * torch.eye(info['A'].shape[0], device=info['A'].device))
            G_inv = torch.linalg.inv(info['G'] + damping * torch.eye(info['G'].shape[0], device=info['G'].device))

            if w.grad is None:
                continue
            grad_W = w.grad.detach()
            if isinstance(mod, nn.Conv1d) and mod.kernel_size == (1,):
                grad_W_mat = grad_W.squeeze(-1)
            else:
                grad_W_mat = grad_W
            precond = G_inv @ grad_W_mat @ A_inv
            if isinstance(mod, nn.Conv1d) and mod.kernel_size == (1,):
                mod.weight.grad.copy_(precond.unsqueeze(-1))
            else:
                mod.weight.grad.copy_(precond)

            if mod.bias is not None and mod.bias.grad is not None:
                grad_b = mod.bias.grad.detach()
                precond_b = G_inv @ grad_b.unsqueeze(-1)
                mod.bias.grad.copy_(precond_b.squeeze(-1))

            info['a'] = None
            info['g'] = None

        # apply gradient descent step
        with torch.no_grad():
            for p in self.model.parameters():
                if p.grad is None:
                    continue
                p -= self.lr * p.grad


In [8]:
# Cell 6 training FNO, adam, sgd, kf-ngd.
from copy import deepcopy

def relative_l2(u_true, u_pred):
    num = np.linalg.norm((u_true - u_pred).reshape(u_true.shape[0], -1), axis=1)
    den = np.linalg.norm(u_true.reshape(u_true.shape[0], -1), axis=1)
    rel = num / den
    return np.mean(rel)

def train_and_evaluate(model, optimizer_name, train_loader, test_loader, num_epochs=50, lr=1e-3, damping=1e-2, gamma=0.05):
    model = deepcopy(model).to(device)
    criterion = nn.MSELoss()
    logs = {'train_loss': []}
    if optimizer_name.lower() == 'adam':
        opt = optim.Adam(model.parameters(), lr=lr)
        kf = None
    elif optimizer_name.lower() == 'sgd':
        opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        kf = None
    elif optimizer_name.lower() == 'kf-ngd':
        opt = None
        kf = KFNGD(model, lr=lr, damping=damping, ema_decay=gamma)
    else:
        raise ValueError("Unknown optimizer")

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        count = 0
        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)
            opt_used = opt
            if opt_used is not None:
                opt_used.zero_grad()
            else:
                # if using KF-NGD, still zero grads
                for p in model.parameters():
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()
            preds = model(xb)
            loss = criterion(preds, yb)
            loss.backward()
            if optimizer_name.lower() == 'kf-ngd':
                # use kf step: it will precondition grads and update params
                kf.step(batch_size=xb.shape[0])
            else:
                opt_used.step()
            running_loss += loss.item() * xb.shape[0]
            count += xb.shape[0]
        avg_loss = running_loss / count
        logs['train_loss'].append(avg_loss)

        # evaluate on test every epoch (could be less frequent)
        model.eval()
        with torch.no_grad():
            y_preds = []
            y_trues = []
            for xb_t, yb_t in test_loader:
                xb_t = xb_t.to(device)
                yb_t = yb_t.to(device)
                out = model(xb_t).cpu().numpy()
                y_preds.append(out)
                y_trues.append(yb_t.cpu().numpy())
            y_preds = np.concatenate(y_preds, axis=0)
            y_trues = np.concatenate(y_trues, axis=0)
            rel_l2 = relative_l2(y_trues, y_preds)
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_loss:.6e} | Test RelL2: {rel_l2:.6e}")
    return model, logs, rel_l2


In [9]:
# Cell 7

# run quick
n_points = 256
n_train = 200
n_test = 50
num_epochs = 10
batch_size = 16

# full-scale 
# n_points = 1024
# n_train = 1000
# n_test = 200
# num_epochs = 30
# batch_size = 16

# 检查数据是否存在
try:
    X_train, Y_train, X_test, Y_test
except NameError:
    print("please run make_dataset() generate: X_train, Y_train, X_test, Y_test")
else:
    # Generate DataLoader
    train_ds = TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(Y_train).float())
    test_ds = TensorDataset(torch.from_numpy(X_test).float(), torch.from_numpy(Y_test).float())
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
    # Create the FNO1d
    base_model = FNO1d(modes=16, width=64, depth=4)
    print(base_model)
    # 运行训练实验（Adam / SGD / KF-NGD）
    results = {}
    for opt_name in ['adam', 'sgd', 'kf-ngd']:
        print(f"\n\n=== Training with {opt_name} ===")
        model_trained, logs, final_rel = train_and_evaluate(
            base_model, opt_name,
            train_loader, test_loader,
            num_epochs=num_epochs,
            lr=1e-3, damping=0.01, gamma=0.05
        )
        results[opt_name] = {
            'model': model_trained,
            'logs': logs,
            'final_rel': final_rel
        }

    print("\n=== Training finished ===")

FNO1d(
  (input_proj): Linear(in_features=1, out_features=64, bias=True)
  (spectral_layers): ModuleList(
    (0-3): 4 x SpectralConv1d()
  )
  (w_conv): ModuleList(
    (0-3): 4 x Conv1d(64, 64, kernel_size=(1,), stride=(1,))
  )
  (output_proj): Sequential(
    (0): Linear(in_features=64, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=1, bias=True)
  )
)


=== Training with adam ===
Epoch 1/10 | Train Loss: nan | Test RelL2: nan
Epoch 2/10 | Train Loss: nan | Test RelL2: nan
Epoch 3/10 | Train Loss: nan | Test RelL2: nan
Epoch 4/10 | Train Loss: nan | Test RelL2: nan
Epoch 5/10 | Train Loss: nan | Test RelL2: nan
Epoch 6/10 | Train Loss: nan | Test RelL2: nan
Epoch 7/10 | Train Loss: nan | Test RelL2: nan
Epoch 8/10 | Train Loss: nan | Test RelL2: nan
Epoch 9/10 | Train Loss: nan | Test RelL2: nan
Epoch 10/10 | Train Loss: nan | Test RelL2: nan


=== Training with sgd ===
Epoch 1/10 | Train Loss: nan | Test RelL2: nan
Epoch 2/10 | Train Los

  loss.backward()


In [28]:
# Cell 8 graph

def plot_results(results):
    plt.figure()
    for name, info in results.items():
        logs = info['logs']
        plt.plot(logs['train_loss'], label=name)
    plt.xlabel("Epoch")
    plt.ylabel("Training MSE Loss")
    plt.yscale("log")
    plt.legend()
    plt.title("Training Loss vs Epochs")
    plt.show()

    for name, info in results.items():
        print(f"{name} final test Relative L2: {info['final_rel']:.6e}")

if 'results' in globals():
    plot_results(results)
else:
    print("No results to plot. Run training first.")


No results to plot. Run training first.


In [1]:
# 清理数据
import gc
import sys

for name in list(globals().keys()):
    if not name.startswith("_") and name not in ["gc", "sys"]:
        del globals()[name]

gc.collect()

try:
    import psutil, os
    process = psutil.Process(os.getpid())
    print(f"Current memory usage: {process.memory_info().rss / 1024 ** 2:.2f} MB")
except ImportError:
    print("psutil not installed; skipped memory usage check.")

print("Memory cleanup done!")


Current memory usage: 77.76 MB
Memory cleanup done!
