# Correction Based approach
Given a FD solution $\bar{y}$, the propagation field $c$, predict the PS solution $y$. We try four main approaches:
- Given $(\bar{y}, c)$, design a NN that outputs $\hat{y} = \bar{y} + e$
- Given $(\bar{y})$, design a NN that outputs $\hat{y} = \bar{y} + e$
- Given $(\bar{y}, c)$, design a NN that outputs $\hat{y} = f(\bar{y})$
- Given $(\bar{y})$, design a NN that outputs $\hat{y} = f(\bar{y})$

The four approaches are tested using train/val/test datasets, on a MLP and a CNN. The loss function is the normalized mean squared error, and the optimizer is AdamW. For fairness, all the architectures are trained on 26000 samples, validated on 1000 and then tested on 1000. Each model is trained for 100 epochs with a fixed learning rate.

In the following code, `x` denotes $\bar{y}$ (the _input_) and $y$ denotes $y$ (the expected _output_).

In [1]:
%loadlibs
from utils import load_datasets, rnmse

Loaded libraries:
	- numpy (np)
	- matplotlib.pyplot (plt)
	- torch
	- torch.nn (nn)
	- torch.optim (optim)
	- tqdm


In [8]:
nx, nt = 128, 256
n_epochs = 10
l_r = 5e-5
device = 'mps'

In [3]:
train, val, test, x_min, x_max, c_min, c_max, y_min, y_max = load_datasets(26000, 1000, 1000)

## Preliminary

In [4]:
train_loss = 0
for xb, cb, yb in train:
    xb, cb, yb = xb.to(device), cb.to(device), yb.to(device)
    train_loss += rnmse(xb, yb)
print(f"Train set initial loss: {train_loss/len(train):.5f}")

val_loss = 0
for xb, cb, yb in val:
    xb, cb, yb = xb.to(device), cb.to(device), yb.to(device)
    val_loss += rnmse(xb, yb)
print(f"Val set initial loss: {val_loss/len(val):.5f}")

test_loss = 0
for xb, cb, yb in test:
    xb, cb, yb = xb.to(device), cb.to(device), yb.to(device)
    test_loss += rnmse(xb, yb)
print(f"Test set initial loss: {test_loss/len(test):.5f}")

Train set initial loss: 0.89193
Val set initial loss: 0.89092
Test set initial loss: 0.89718


## 1. Correction term, using $\bar{y}$ and $c$
The first approach designs a neural network outputing a correction term added to the FD solution.

In [9]:
class CorrectionMLP(nn.Module):
    def __init__(self, nx, hidden=64):
        super().__init__()
        self.c_encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(nx * nx, hidden*2),
            nn.Tanh(),
            nn.Linear(hidden*2, hidden),
            nn.Tanh(),
        )
        self.xt_block = nn.Sequential(
            nn.Linear(nx + hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden*2),
            nn.Tanh(),
            nn.Linear(hidden*2, nx),
        )
        self.correction=nn.Sequential(
            nn.Linear(nx, hidden),
            nn.Tanh(),
            nn.Linear(hidden, nx),
        )

    def forward(self, x, c):
        n, nt, nx = x.shape

        c_latent = self.c_encoder(c)
        c_latent = c_latent[:, None, :].expand(n, nt, -1)

        inp = torch.cat([x, c_latent], dim=-1)
        y = self.xt_block(inp)
        y = self.correction(y)
        return y + x

model = CorrectionMLP(128, hidden=256).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=l_r)
best_val_loss = 10
for epoch in (pbar:=tqdm(range(n_epochs))):
    model.train()
    for xb, cb, yb in train:
        xb, cb, yb = xb.to(device), cb.to(device), yb.to(device)
        pred = model(xb, cb)
        loss = rnmse(pred, yb)
        opt.zero_grad()
        loss.backward()
        opt.step()
        pbar.set_postfix(loss=loss.item())
    model.eval()
    val_loss = 0
    for xb, cb, yb in val:
        xb, cb, yb = xb.to(device), cb.to(device), yb.to(device)
        val_loss += rnmse(model(xb, cb), yb)
    print(f"Validation loss: {val_loss/len(val):.5f}")
    if val_loss/len(val) < best_val_loss:
        best_val_loss = val_loss/len(val)
        torch.save(model, "saved_models/correction_mlp.pt")

model = torch.load("saved_models/correction_mlp.pt", weights_only=False)
test_loss = 0
for xb, cb, yb in test:
    xb, cb, yb = xb.to(device), cb.to(device), yb.to(device)
    test_loss += rnmse(model(xb, cb), yb)
print(f"Test loss: {test_loss/len(test):.5f}")

  0%|          | 0/10 [00:00<?, ?it/s]

Validation loss: 0.37236
Validation loss: 0.37308
Validation loss: 0.36848
Validation loss: 0.36648
Validation loss: 0.36276
Validation loss: 0.35888
Validation loss: 0.35443
Validation loss: 0.36026
Validation loss: 0.34547
Validation loss: 0.36359
Test loss: 0.35055


We do the same using a simple CNN.

In [7]:
class CorrectionCNN(nn.Module):
    def __init__(self, nx, hidden=8):
        super().__init__()

        self.c_encoder = nn.Sequential(
            nn.Conv2d(1, hidden//4, kernel_size=3, padding=1),
            nn.Tanh(),
            nn.Conv2d(hidden//4, hidden//2, kernel_size=3, padding=1),
            nn.Tanh(),
            nn.Conv2d(hidden//2, hidden, kernel_size=3, padding=1),
            nn.Tanh(),
        )
        self.c_pool = nn.AdaptiveAvgPool2d(1) 
        self.xt_block = nn.Sequential(
            nn.Conv2d(1 + 1, hidden, kernel_size=3, padding=1),
            nn.Tanh(),
            nn.Conv2d(hidden, hidden//2, kernel_size=3, padding=1),
            nn.Tanh(),
            nn.Conv2d(hidden//2, 1, kernel_size=3, padding=1),
        )
        self.correction = nn.Sequential(
            nn.Conv2d(1, hidden, kernel_size=3, padding=1),
            nn.Tanh(),
            nn.Conv2d(hidden, 1, kernel_size=3, padding=1),
        )

    def forward(self, x, c):
        B, nt, nx = x.shape
        _, cnx1, cnx2 = c.shape
        x_img = x.unsqueeze(1) 
        c_img = c.unsqueeze(1)
        c_latent = self.c_encoder(c_img)
        c_latent = self.c_pool(c_latent)
        c_latent = c_latent.expand(B, -1, nt, nx)
        inp = torch.cat([x_img, c_latent.mean(dim=1, keepdim=True)], dim=1)
        y = self.xt_block(inp)
        y = self.correction(y)
        return y.squeeze(1) + x

model = CorrectionCNN(128, hidden=8).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=l_r)
best_val_loss = 10
for epoch in (pbar:=tqdm(range(n_epochs))):
    model.train()
    for xb, cb, yb in train:
        xb, cb, yb = xb.to(device), cb.to(device), yb.to(device)
        pred = model(xb, cb)
        loss = nmse(pred, yb)
        opt.zero_grad()
        loss.backward()
        opt.step()
        pbar.set_postfix(loss=loss.item())
    model.eval()
    val_loss = 0
    for xb, cb, yb in val:
        xb, cb, yb = xb.to(device), cb.to(device), yb.to(device)
        val_loss += nmse(model(xb, cb), yb)
    print(f"Validation loss: {val_loss/len(val):.5f}")
    if val_loss/len(val) < best_val_loss:
        best_val_loss = val_loss/len(val)
        torch.save(model, "saved_models/correction_cnn.pt")

model = torch.load("saved_models/correction_cnn.pt", weights_only=False)
test_loss = 0
for xb, cb, yb in test:
    xb, cb, yb = xb.to(device), cb.to(device), yb.to(device)
    test_loss += nmse(model(xb, cb), yb)
print(f"Test loss: {test_loss/len(test):.5f}")

  0%|          | 0/10 [00:00<?, ?it/s]

Validation loss: 10.30507
Validation loss: 1.55229
Validation loss: 1.46175
Validation loss: 1.34683
Validation loss: 1.21439
Validation loss: 1.07467
Validation loss: 0.93890
Validation loss: 0.81581
Validation loss: 0.70908
Validation loss: 0.61688
Test loss: 0.61846


## 2. Correction term, using $\bar{y}$
The second approach designs a neural network outputing a correction term added to the FD solution.

In [8]:
class CorrectionMLP2(nn.Module):
    def __init__(self, nx, hidden=64):
        super().__init__()
        self.xt_block = nn.Sequential(
            nn.Linear(nx, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden//2),
            nn.Tanh(),
            nn.Linear(hidden//2, nx),
        )
        self.correction= nn.Sequential(
            nn.Linear(nx, hidden),
            nn.Tanh(),
            nn.Linear(hidden, nx),
        )

    def forward(self, x):
        n, nt, nx = x.shape
        y = self.xt_block(x)
        y = self.correction(y)
        return y + x

model = CorrectionMLP2(128, hidden=64).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=l_r)
best_val_loss = 10
for epoch in (pbar:=tqdm(range(n_epochs))):
    model.train()
    for xb, _, yb in train:
        xb, yb = xb.to(device), yb.to(device)
        pred = model(xb)
        loss = nmse(pred, yb)
        opt.zero_grad()
        loss.backward()
        opt.step()
        pbar.set_postfix(loss=loss.item())
    model.eval()
    val_loss = 0
    for xb, _, yb in val:
        xb, yb = xb.to(device), yb.to(device)
        val_loss += nmse(model(xb), yb)
    print(f"Validation loss: {val_loss/len(val):.5f}")
    if val_loss/len(val) < best_val_loss:
        best_val_loss = val_loss/len(val)
        torch.save(model, "saved_models/correction_mlp_2.pt")

model = torch.load("saved_models/correction_mlp_2.pt", weights_only=False)
test_loss = 0
for xb, _, yb in test:
    xb, yb = xb.to(device), yb.to(device)
    test_loss += nmse(model(xb), yb)
print(f"Test loss: {test_loss/len(test):.5f}")

  0%|          | 0/10 [00:00<?, ?it/s]


KeyboardInterrupt



In [None]:
class CorrectionCNN2(nn.Module):
    def __init__(self, nx, hidden=8):
        super().__init__()
        self.xt_block = nn.Sequential(
            nn.Conv2d(1, hidden, kernel_size=3, padding=1),
            nn.Tanh(),
            nn.Conv2d(hidden, hidden//2, kernel_size=3, padding=1),
            nn.Tanh(),
            nn.Conv2d(hidden//2, 1, kernel_size=3, padding=1),
        )
        self.correction = nn.Sequential(
            nn.Conv2d(1, hidden, kernel_size=3, padding=1),
            nn.Tanh(),
            nn.Conv2d(hidden, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        B, nt, nx = x.shape
        x_img = x.unsqueeze(1) 
        y = self.xt_block(x_img)
        y = self.correction(y)
        return y.squeeze(1) + x

model = CorrectionCNN2(128, hidden=8).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=l_r)
best_val_loss = 10
for epoch in (pbar:=tqdm(range(n_epochs))):
    model.train()
    for xb, _, yb in train:
        xb, yb = xb.to(device), yb.to(device)
        pred = model(xb)
        loss = nmse(pred, yb)
        opt.zero_grad()
        loss.backward()
        opt.step()
        pbar.set_postfix(loss=loss.item())
    model.eval()
    val_loss = 0
    for xb, _, yb in val:
        xb, yb = xb.to(device), yb.to(device)
        val_loss += nmse(model(xb), yb)
    print(f"Validation loss: {val_loss/len(val):.5f}")
    if val_loss/len(val) < best_val_loss:
        best_val_loss = val_loss/len(val)
        torch.save(model, "saved_models/correction_cnn_2.pt")

model = torch.load("saved_models/correction_cnn_2.pt", weights_only=False)
test_loss = 0
for xb, _, yb in test:
    xb, yb = xb.to(device), yb.to(device)
    test_loss += nmse(model(xb), yb)
print(f"Test loss: {test_loss/len(test):.5f}")