In [6]:
import torch
import torch.nn as nn
import torch.optim as optim

In [7]:
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)
    def forward(self, x):
        h = torch.relu(self.fc1(x))
        out = self.fc2(h)
        return out, h

In [10]:
class KFNGD:
    def __init__(self, model, lr=1e-3, gamma=0.05, damping=1e-3):
        self.model = model
        self.lr = lr
        self.gamma = gamma
        self.damping = damping
        self.A_hat = None
        self.G_hat = None

    def step(self, loss, activations):
        loss.backward()
        grads = []
        for p in self.model.parameters():
            if p.grad is not None:
                grads.append(p.grad.view(-1))
        g_bar = torch.cat(grads).detach()
        a_bar = activations.detach()

        # 计算 batch 协方差 (这里只写单样本版；batch 需按公式求和后 /B)
        A_batch = torch.outer(a_bar, a_bar)
        G_batch = torch.outer(g_bar, g_bar)

        if self.A_hat is None:
            self.A_hat, self.G_hat = A_batch, G_batch
        else:
            self.A_hat = (1-self.gamma)*self.A_hat + self.gamma*A_batch
            self.G_hat = (1-self.gamma)*self.G_hat + self.gamma*G_batch
            
        I_A = torch.eye(self.A_hat.shape[0])
        I_G = torch.eye(self.G_hat.shape[0])
        A_inv = torch.linalg.inv(self.A_hat + self.damping * I_A)
        G_inv = torch.linalg.inv(self.G_hat + self.damping * I_G)

        # natural gradient
        a_tilde = A_inv @ a_bar
        g_tilde = G_inv @ g_bar
        g_nat = torch.kron(a_tilde, g_tilde)

        # 更新参数
        offset = 0
        for p in self.model.parameters():
            if p.grad is not None:
                n = p.numel()
                p.data -= self.lr * g_nat[offset:offset+n].view_as(p)
                offset += n

        self.model.zero_grad()

In [11]:
torch.manual_seed(0)
net = SimpleNet()
opt = KFNGD(net, lr=1e-3)

for step in range(200): 
    x = torch.randn(4, 10)
    y = torch.randn(4, 1)
    out, h = net(x)
    loss = ((out - y) ** 2).mean()

    opt.step(loss, h.mean(0))

    print(f"step {step} loss={loss.item():.4f}")

step 0 loss=1.8181
step 1 loss=0.8767
step 2 loss=0.2927
step 3 loss=0.8811
step 4 loss=0.5779
step 5 loss=0.6285
step 6 loss=1.7943
step 7 loss=0.9657
step 8 loss=1.6610
step 9 loss=1.2665
step 10 loss=0.5698
step 11 loss=0.9160
step 12 loss=0.7589
step 13 loss=0.3513
step 14 loss=0.2464
step 15 loss=0.3675
step 16 loss=0.3665
step 17 loss=1.0077
step 18 loss=0.2663
step 19 loss=0.3218
step 20 loss=1.6320
step 21 loss=0.3754
step 22 loss=0.8376
step 23 loss=0.7988
step 24 loss=1.0424
step 25 loss=1.7393
step 26 loss=0.1959
step 27 loss=0.6837
step 28 loss=1.1714
step 29 loss=2.7500
step 30 loss=0.8828
step 31 loss=1.2167
step 32 loss=1.2394
step 33 loss=2.4025
step 34 loss=1.1102
step 35 loss=0.4996
step 36 loss=0.2793
step 37 loss=0.9932
step 38 loss=1.8271
step 39 loss=1.2606
step 40 loss=0.6624
step 41 loss=1.0632
step 42 loss=0.7606
step 43 loss=0.8815
step 44 loss=1.5093
step 45 loss=3.5758
step 46 loss=1.8410
step 47 loss=1.2291
step 48 loss=0.7034
step 49 loss=0.1855
step 50 lo