In [1]:
import torch
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Training snapshot matrices, one file per parameter value (per the file names).
data1 = np.load("mu_0.9.npy")
data2 = np.load("mu_0.95.npy")
data3 = np.load("mu_1.05.npy")
data4 = np.load("mu_1.1.npy")

# Time increment used for the finite-difference derivatives below.
# NOTE(review): the time-integration cell further down uses 2 / 1500;
# confirm which spacing matches the number of stored snapshots.
dt = 2 / 1501

# np.diff along axis 0 yields row[k + 1] - row[k]. Every difference quotient
# is paired with the *later* snapshot, so the first row of each data set is
# dropped after its derivative has been formed.
data1_dot = np.diff(data1, axis=0) / dt
data1 = data1[1:]

data2_dot = np.diff(data2, axis=0) / dt
data2 = data2[1:]

data3_dot = np.diff(data3, axis=0) / dt
data3 = data3[1:]

data4_dot = np.diff(data4, axis=0) / dt
data4 = data4[1:]

# Stack all data sets row-wise: S rows in total, N columns per row.
X = np.concatenate((data1, data2, data3, data4), axis=0)
X_dot = np.concatenate((data1_dot, data2_dot, data3_dot, data4_dot), axis=0)
S, N = X.shape

# Held-out data set, used later to initialise the reduced-model run.
X_test = np.load("mu_1.0.npy")

In [None]:
# Prefer the GPU when one is usable, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class NRBS(torch.nn.Module):
    """Autoencoder with a dense encoder and a sparsity-masked decoder output layer.

    The encoder maps N -> M1 -> n and the decoder maps n -> M2 -> N, both with
    an ``x * sigmoid(x)`` activation after the first layer. The last decoder
    layer's weight matrix is multiplied element-wise by a fixed 0/1 mask, so
    each output entry only sees a subset of the M2 hidden units.

    Args:
        N: size of the full state vector (input and output dimension).
        n: size of the latent (reduced) state.
        M1: width of the encoder hidden layer.
        M2: width of the decoder hidden layer.
        b: number of consecutive decoder hidden units in each output's band.
        grid_shape: ``(rows, cols)`` of the row-major grid the N outputs are
            laid out on; used to find each output's 4-connected neighbours.
            Defaults to ``(60, 60)``, the value previously hard-coded.
    """

    def __init__(self, N, n, M1, M2, b, grid_shape=(60, 60)):
        super().__init__()

        rows, cols = grid_shape

        # Banded mask: output i is connected to b consecutive hidden units
        # whose starting column slides linearly from 0 (i = 0) to M2 - b
        # (i = N - 1). With a single output there is nothing to slide.
        band = torch.zeros(N, M2)
        shift = (M2 - b) / (N - 1) if N > 1 else 0.0
        for i in range(N):
            start = int(np.ceil(shift * i))
            band[i, start:start + b] = 1

        # Widen every output's band with the bands of its grid neighbours
        # (itself, up, down, right, left). Neighbours outside the grid, or
        # whose flat index does not exist among the N outputs, are skipped.
        mask = torch.zeros(N, M2)
        for idx in range(N):
            i, j = divmod(idx, cols)
            for ni, nj in ((i, j), (i - 1, j), (i + 1, j), (i, j + 1), (i, j - 1)):
                if 0 <= ni < rows and 0 <= nj < cols:
                    flat = cols * ni + nj
                    if flat < N:
                        mask[idx] += band[flat]

        # Collapse the accumulated counts back to a 0/1 mask.
        mask[mask > 0] = 1

        # A buffer moves with .to(device) and is saved in the state dict,
        # but is not a trainable parameter.
        self.register_buffer('mask', mask)

        self.encoder1 = torch.nn.Linear(N, M1)
        self.encoder2 = torch.nn.Linear(M1, n)

        self.decoder1 = torch.nn.Linear(n, M2)
        self.decoder2 = torch.nn.Linear(M2, N)

        for layer in (self.encoder1, self.encoder2, self.decoder1, self.decoder2):
            torch.nn.init.kaiming_normal_(layer.weight)

    def encode(self, x):
        """Map full states (last dimension N) to latent states (last dimension n)."""
        x = self.encoder1(x)
        x = x * torch.sigmoid(x)
        x = self.encoder2(x)
        return x

    def decode(self, x):
        """Map latent states (last dimension n) back to full states (last dimension N)."""
        x = self.decoder1(x)
        x = x * torch.sigmoid(x)
        # Apply the sparsity mask to the weights on every call so that masked
        # entries never contribute, whatever values they hold.
        masked_weight = self.decoder2.weight * self.mask
        x = torch.matmul(x, masked_weight.T) + self.decoder2.bias
        return x

    def forward(self, x):
        """Encode and then decode ``x``."""
        return self.decode(self.encode(x))

1.5585333730830057e-05

In [None]:
# Build the autoencoder with explicit layer sizes, then move its parameters
# and mask buffer to the selected device.
# NOTE(review): `n` is not assigned in any cell visible here; confirm it is
# defined before this cell runs on a fresh kernel.
nrbs = NRBS(N=N, n=n, M1=6728, M2=33730, b=70)
nrbs = nrbs.to(device)

In [4]:
# Project the snapshots and their time derivatives onto `basis`.
# NOTE(review): `basis` is not assigned in any cell visible here; confirm it
# is defined before this cell runs on a fresh kernel.
reduced_X, reduced_X_dot = (
    np.matmul(snapshots, basis) for snapshots in (X, X_dot)
)

In [5]:
# Regression matrix: each row holds a reduced state x followed by the
# flattened outer product x x^T (n linear + n*n quadratic features).
D = np.zeros((S, n + n * n))
D[:, :n] = reduced_X
# Broadcasting forms x_i * x_j for every row at once; the row-major reshape
# matches the element order of np.outer(x, x).flatten().
D[:, n:] = (reduced_X[:, :, None] * reduced_X[:, None, :]).reshape(S, n * n)

# Least-squares fit of reduced_X_dot ~ D @ AH_red.
AH_red = np.linalg.lstsq(D, reduced_X_dot, rcond=None)[0]

# Split the stacked solution into the linear (n x n) and quadratic
# (n x n*n) operators, transposed so they act on column vectors.
A_red = AH_red[:n, :].T
H_red = AH_red[n:, :].T


In [6]:
def nr_solve(f, df, x, atol=1e-10, rtol=1e-8, max_itr=50, args=None):
    """Solve ``f(x, *args) = 0`` with Newton-Raphson iterations.

    Args:
        f: residual function, called as ``f(x, *args)``.
        df: Jacobian of ``f``, called as ``df(x, *args)``; its result is
            passed to ``np.linalg.solve`` as the system matrix.
        x: initial guess.
        atol: absolute tolerance on the residual 2-norm.
        rtol: tolerance on the residual norm relative to the initial one.
        max_itr: maximum number of Newton updates.
        args: extra positional arguments forwarded to ``f`` and ``df``;
            ``None`` means no extra arguments.

    Returns:
        The first iterate whose residual satisfies either tolerance (the
        initial guess itself if it already meets ``atol``).

    Raises:
        RuntimeError: if no iterate converges within ``max_itr`` updates.
    """
    # Unpacking None with * raises TypeError, so normalise the default.
    if args is None:
        args = ()

    r = f(x, *args)
    r0_norm = np.linalg.norm(r)
    print("Itr = {:}, residual norm = {:.4E}".format(0, r0_norm))
    if r0_norm < atol:
        return x

    for i in range(max_itr):
        # Newton update: solve J * delta = -r and step to x + delta.
        x = np.linalg.solve(df(x, *args), -r) + x
        r = f(x, *args)
        r_norm = np.linalg.norm(r)
        print("Itr = {:}, residual norm = {:.4E}".format(i + 1, r_norm))
        if r_norm < atol or r_norm / r0_norm < rtol:
            return x

    raise RuntimeError("solve failed")
        

In [7]:
def f(x, x_old, dt):
    """Residual of one implicit step: ``x - x_old - dt * s_dot(x)``.

    ``s_dot(x) = A_red @ x + H_red @ vec(x x^T)`` is the quadratic reduced
    model evaluated at the new state ``x``.
    """
    quadratic = np.outer(x, x).flatten()
    s_dot = np.matmul(A_red, x) + np.matmul(H_red, quadratic)
    return x - x_old - dt * s_dot

def df(x, x_old, dt):
    """Jacobian of ``f`` with respect to ``x``.

    ``x_old`` is unused; it is accepted so that ``f`` and ``df`` share the
    same argument list.
    """
    H3 = H_red.reshape(n, n, n)
    # d/dx_k sum_ij H[m,i,j] x_i x_j = sum_j H[m,k,j] x_j + sum_i H[m,i,k] x_i.
    # Both contractions are kept so the result is exact even when H is not
    # symmetric in its last two indices (it reduces to 2 * H3 @ x when it is).
    quadratic_jac = np.matmul(H3, x) + np.einsum('mik,i->mk', H3, x)
    return np.eye(n) - (A_red + quadratic_jac) * dt

In [8]:
# Start from the first test snapshot projected onto `basis`.
x_red = np.matmul(X_test[0], basis)
t = 0
# Solution history keyed by time.
sol = {t : x_red}
dt = 2 / 1500
dt_min = dt / 100
# NOTE(review): t is accumulated in floating point, so the last step may end
# slightly past 2 or one extra step may be taken; confirm this is acceptable.
while t < 2:
    print('Time = {:}, dt = {:}'.format(t + dt, dt))
    try:
        # Implicit step: the previous state is both the initial guess and x_old.
        x_red = nr_solve(f, df, x_red, args=(x_red, dt))
    except Exception as err:
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt / SystemExit still stop the loop. On a failed
        # solve, halve the step and retry from the same state; dt is never
        # increased again afterwards.
        dt = dt / 2
        if (dt < dt_min):
            raise Exception("dt is too small") from err
        continue
    t = t + dt
    sol[t] = x_red

Time = 0.0013333333333333333, dt = 0.0013333333333333333
Itr = 0, residual norm = 7.6241E-02
Itr = 1, residual norm = 3.9236E-06
Itr = 2, residual norm = 1.3240E-12
Time = 0.0026666666666666666, dt = 0.0013333333333333333
Itr = 0, residual norm = 7.6035E-02
Itr = 1, residual norm = 3.9352E-06
Itr = 2, residual norm = 1.3319E-12
Time = 0.004, dt = 0.0013333333333333333
Itr = 0, residual norm = 7.5832E-02
Itr = 1, residual norm = 3.9451E-06
Itr = 2, residual norm = 1.3394E-12
Time = 0.005333333333333333, dt = 0.0013333333333333333
Itr = 0, residual norm = 7.5634E-02
Itr = 1, residual norm = 3.9534E-06
Itr = 2, residual norm = 1.3464E-12
Time = 0.006666666666666666, dt = 0.0013333333333333333
Itr = 0, residual norm = 7.5440E-02
Itr = 1, residual norm = 3.9605E-06
Itr = 2, residual norm = 1.3529E-12
Time = 0.008, dt = 0.0013333333333333333
Itr = 0, residual norm = 7.5251E-02
Itr = 1, residual norm = 3.9665E-06
Itr = 2, residual norm = 1.3589E-12
Time = 0.009333333333333334, dt = 0.00133333

Exception: dt is too small

In [10]:
# Display the time of the last successfully completed step
# (t is only advanced after a solve succeeds).
t

0.15643750000000003