In [1]:
import torch
import numpy as np

torch.set_default_dtype(torch.float64)

from torch.utils.tensorboard import SummaryWriter
import tqdm

from matplotlib import pyplot as plt

In [2]:
device = torch.device("cuda")

In [3]:
data1 = torch.tensor(np.load("mu_0.9.npy")).to(device)
data2 = torch.tensor(np.load("mu_0.95.npy")).to(device)
data3 = torch.tensor(np.load("mu_1.05.npy")).to(device)
data4 = torch.tensor(np.load("mu_1.1.npy")).to(device)

dt = 2 / 1500

X = torch.cat([data1, data2, data3, data4], axis=0)
x_ref = torch.mean(X, dim=0)
(S, N) = X.shape

X_test = torch.tensor(np.load("mu_1.0.npy")).to(device)

In [4]:
class NRBS(torch.nn.Module):
    def __init__(self, N, n, M1, M2, b):
        super(NRBS, self).__init__()

        mask = torch.zeros(N, M2)
        shift = (M2 - b) / (N - 1)

        for i in range(N):
            mask[i, int(np.ceil(shift*i)): int(np.ceil(shift*i)) + b] = 1

        new_mask = torch.zeros(N, M2)

        for idx in range(N):
            i = idx // 60
            j = idx % 60
            neighbours = [(i, j), (i-1, j), (i+1, j), (i, j+1), (i, j-1)]
            for neighbour_i, neighbour_j in neighbours:
              if (neighbour_i >=0 and neighbour_i < 60) and (neighbour_j >=0 and neighbour_j < 60):
                new_mask[idx] = new_mask[idx] + mask[60*neighbour_i + neighbour_j]

        new_mask[new_mask > 0] = 1

        self.register_buffer('mask', new_mask)


        self.encoder1 = torch.nn.Linear(N, M1)
        self.encoder2 = torch.nn.Linear(M1, n)

        self.decoder1 = torch.nn.Linear(n, M2)
        self.decoder2 = torch.nn.Linear(M2, N)

        torch.nn.init.kaiming_normal_(self.encoder1.weight)
        torch.nn.init.kaiming_normal_(self.encoder2.weight)
        torch.nn.init.kaiming_normal_(self.decoder1.weight)
        torch.nn.init.kaiming_normal_(self.decoder2.weight)

    def encode(self, x):
        x = self.encoder1(x)
        x = x * torch.sigmoid(x)
        x = self.encoder2(x)
        return x

    def decode(self, x):
        x = self.decoder1(x)
        x = x * torch.sigmoid(x)
        x = torch.matmul(x, (self.decoder2.weight * self.get_buffer('mask')).T) + self.decoder2.bias
        return x

    def forward(self, x):
        return self.decode(self.encode(x))


In [15]:
n = 20
nx = 60
ny = 60

nrbs = NRBS(N, n, 6728, 33730, 70).to(device)
nrbs = torch.load('models/shallow_mask_old.pth')

X_tilde = nrbs(X_test - x_ref) + x_ref
l_test = torch.sqrt(torch.sum((X_test - X_tilde) ** 2)) / torch.sqrt(
    torch.sum(X_test**2)
)
l_test

tensor(0.0022, device='cuda:0', grad_fn=<DivBackward0>)

In [6]:
def u_hat_and_u_hat_dot(nrbs, u, u_ref, dt):
    with torch.no_grad():
        u_hat = nrbs.encode(u - u_ref)
        u_hat_dot = (u_hat[1:] - u_hat[:-1]) / dt
        u_hat = u_hat[1:]
    return u_hat, u_hat_dot

In [7]:
data1_hat, data1_hat_dot = u_hat_and_u_hat_dot(nrbs, data1, x_ref, dt)
data2_hat, data2_hat_dot = u_hat_and_u_hat_dot(nrbs, data2, x_ref, dt)
data3_hat, data3_hat_dot = u_hat_and_u_hat_dot(nrbs, data3, x_ref, dt)
data4_hat, data4_hat_dot = u_hat_and_u_hat_dot(nrbs, data4, x_ref, dt)


In [19]:
X_hat = torch.cat([data1_hat, data2_hat, data3_hat, data4_hat])
X_hat_dot = torch.cat([data1_hat_dot, data2_hat_dot, data3_hat_dot, data4_hat_dot])

X_hat_test, X_hat_dot_test = u_hat_and_u_hat_dot(nrbs, X_test, x_ref, dt)

(S, n) = X_hat.shape

In [23]:
X_hat = X_hat.detach().cpu().numpy()
X_hat_dot = X_hat_dot.detach().cpu().numpy()

X_hat_test = X_hat_test.detach().cpu().numpy()
X_hat_dot_test = X_hat_dot_test.detach().cpu().numpy()

In [25]:
D = np.zeros((S, n + n * n))
for i in range(S):
    D[i, 0:n] = X_hat[i]
    D[i, n:] = np.outer(X_hat[i], X_hat[i]).flatten()


In [28]:
res = np.linalg.lstsq(D, X_hat_dot, rcond=1e-11)
AH_red = res[0]
A_red = AH_red[0:n, :].T
H_red = AH_red[n:, :].T

In [29]:
def nr_solve(f, df, x, atol=1e-10, rtol=1e-8, max_itr=50, args=None):
    r = f(x, *args)
    r0_norm = np.linalg.norm(r)
    print("Itr = {:}, residual norm = {:.4E}".format(0, r0_norm))
    if (r0_norm < atol):
        return x
    for i in range(max_itr):
        x = np.linalg.solve(df(x, *args), -r) + x
        r = f(x, *args)
        print("Itr = {:}, residual norm = {:.4E}".format(i + 1, np.linalg.norm(r)))
        if (np.linalg.norm(r) < atol or np.linalg.norm(r) / r0_norm < rtol):
            return x
    raise Exception("solve failed")
        

In [30]:
# replace with NN: u_hat to u_hat_dot
def f(x, x_old, dt):
    s_dot = np.matmul(A_red, x) + np.matmul(H_red, np.outer(x, x).flatten())
    return x - x_old - dt * s_dot

def df(x, x_old, dt):
    return np.eye(n) - (A_red + 2 * np.matmul(H_red.reshape(n, n, n), x)) * dt

In [32]:
x_red = X_hat_test[1]
t = 0
sol = {t : x_red}
dt = 2 / 1500
dt_min = dt / 100
while t < 2:
    print('Time = {:}, dt = {:}'.format(t + dt, dt))
    try:
        x_red = nr_solve(f, df, x_red, args=(x_red, dt))
    except:
        dt = dt / 2
        if (dt < dt_min):
            raise Exception("dt is too small")
        continue
    t = t + dt
    sol[t] = x_red

Time = 0.0013333333333333333, dt = 0.0013333333333333333
Itr = 0, residual norm = 4.6713E-02
Itr = 1, residual norm = 3.8494E-04
Itr = 2, residual norm = 5.6495E-08
Itr = 3, residual norm = 2.7834E-15
Time = 0.0026666666666666666, dt = 0.0013333333333333333
Itr = 0, residual norm = 4.6572E-02
Itr = 1, residual norm = 3.8635E-04
Itr = 2, residual norm = 5.6919E-08
Itr = 3, residual norm = 3.0361E-15
Time = 0.004, dt = 0.0013333333333333333
Itr = 0, residual norm = 4.6423E-02
Itr = 1, residual norm = 3.8739E-04
Itr = 2, residual norm = 5.7388E-08
Itr = 3, residual norm = 3.0612E-15
Time = 0.005333333333333333, dt = 0.0013333333333333333
Itr = 0, residual norm = 4.6283E-02
Itr = 1, residual norm = 3.8746E-04
Itr = 2, residual norm = 5.8157E-08
Itr = 3, residual norm = 3.3276E-15
Time = 0.006666666666666666, dt = 0.0013333333333333333
Itr = 0, residual norm = 4.6142E-02
Itr = 1, residual norm = 3.8603E-04
Itr = 2, residual norm = 5.9513E-08
Itr = 3, residual norm = 4.0093E-15
Time = 0.008,

Exception: dt is too small

In [45]:
t/(2/1500)

15.781250000000016

In [49]:
ref_sol = X_test[15].detach().cpu().numpy()
approximate = (nrbs.decode(torch.tensor(sol[t]).to(device)) + x_ref).detach().cpu().numpy()
np.sqrt(np.sum(approximate - ref_sol)**2) / np.sqrt(np.sum(ref_sol**2))

226.77338523132983

In [35]:
# approximate = (nrbs.decode(sol[t]) + x_ref).detach().cpu().numpy()
# # plot
# ny = 60
# nx = 60
# x = np.load('paper_x.npy')
# y = np.load('paper_y.npy')


# fig_u = plt.figure()
# ax_u = fig_u.gca()
# p_u=ax_u.pcolor(x.reshape(ny,nx), y.reshape(ny,nx), (approximate).reshape(ny,nx))
# cb_u=fig_u.colorbar(p_u,ax=ax_u)
# ax_u.set_xlabel('$x$')
# ax_u.set_ylabel('$y$')
# plt.title('$u$ ($t = {:}$)'.format(t))
# plt.show()

In [22]:
# for t, x_red in sol.items():
#     approximate = (nrbs.decode(x_red) + x_ref).detach().cpu().numpy()
#     # plot
#     ny = 60
#     nx = 60
#     x = np.load('paper_x.npy')
#     y = np.load('paper_y.npy')


#     fig_u = plt.figure()
#     ax_u = fig_u.gca()
#     p_u=ax_u.pcolor(x.reshape(ny,nx), y.reshape(ny,nx), (approximate).reshape(ny,nx))
#     cb_u=fig_u.colorbar(p_u,ax=ax_u)
#     ax_u.set_xlabel('$x$')
#     ax_u.set_ylabel('$y$')
#     plt.title('$u$ ($t = {:}$)'.format(t))
#     plt.show()