In [1]:
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
filelist = Path('../data/ns').glob('*.npy')

reynolds = []
data = []
for filepath in filelist:
    reynolds.append( int(filepath.stem.split('_')[-1]) )
    data.append( np.load(filepath) )

In [3]:
from utils import animate

sequence_id = 0
sequence = data[sequence_id]
Re = reynolds[sequence_id]
u = sequence[:,:,::3]
v = sequence[:,:,1::3]
p = sequence[:,:,2::3]
animate(u)

In [4]:
H, W, T = sequence.shape
T = T//3

x = torch.arange(0, W)*1.
y = 0.5*H - torch.arange(0, H)
xx, yy = torch.meshgrid(x, y, indexing='xy')

In [5]:
for t in range(T):
    tt = t*torch.ones_like(xx)
    X = torch.stack([tt, xx, yy], 0)
    X = torch.flatten(X, start_dim=1).T
    Y = torch.tensor(sequence[:,:,3*t:3*t+3], dtype=X.dtype)
    Y = torch.flatten(Y, end_dim=1)
    if t == 0:
        x_train = X
        y_train = Y
    else:
        x_train = torch.cat([x_train,X], 0)
        y_train = torch.cat([y_train,Y], 0)

In [6]:
from torch.utils.data import TensorDataset, DataLoader
train_dataset = TensorDataset(x_train, y_train)
train_dataloader = DataLoader(train_dataset, batch_size=1024, shuffle=True)

In [7]:
class Backbone(nn.Module):
    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(3, 64)  # input dim = 3 (t, x, y)
        self.fc2 = nn.Linear(64, 128) # hidden dims = [32, 16]
        self.out = nn.Linear(128, 3)  # output dim = 3 (u, v, p)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.LeakyReLU()(x) 
        x = self.fc2(x)
        x = nn.LeakyReLU()(x) 
        return self.out(x)

def navier_stokes_loss(Re):
    def func(model, X):
        t = X[...,0].reshape(-1,1).requires_grad_(True)
        x = X[...,1].reshape(-1,1).requires_grad_(True)
        y = X[...,2].reshape(-1,1).requires_grad_(True)
        X_ = torch.cat([t,x,y], axis=-1)
        out = model.forward(X_)
        u = out[...,0].reshape(-1,1)
        v = out[...,1].reshape(-1,1)
        p = out[...,2].reshape(-1,1)

        # derivatives
        u_x = torch.autograd.grad(u.sum(), x, create_graph=True)[0]
        u_y = torch.autograd.grad(u.sum(), y, create_graph=True)[0]
        u_t = torch.autograd.grad(u.sum(), t, create_graph=True)[0]
        v_x = torch.autograd.grad(v.sum(), x, create_graph=True)[0]
        v_y = torch.autograd.grad(v.sum(), y, create_graph=True)[0]
        v_t = torch.autograd.grad(v.sum(), t, create_graph=True)[0]
        p_x = torch.autograd.grad(p.sum(), x, create_graph=True)[0]
        p_y = torch.autograd.grad(p.sum(), y, create_graph=True)[0]
        
        u_xx = torch.autograd.grad(u_x.sum(), x, create_graph=True)[0]
        u_yy = torch.autograd.grad(u_y.sum(), y, create_graph=True)[0]
        v_xx = torch.autograd.grad(v_x.sum(), x, create_graph=True)[0]
        v_yy = torch.autograd.grad(v_y.sum(), y, create_graph=True)[0]

        
        f_equation_mass = u_x + v_y
        f_equation_x = u_t + (u * u_x + v * u_y) + p_x - 1.0 / Re * (u_xx + u_yy)
        f_equation_y = v_t + (u * v_x + v * v_y) + p_y - 1.0 / Re * (v_xx + v_yy)

        mse = torch.nn.MSELoss()
        batch_t_zeros = torch.zeros_like(x, dtype=torch.float32, device=device)
        mse_equation = mse(f_equation_x, batch_t_zeros) + mse(f_equation_y, batch_t_zeros) + \
                        mse(f_equation_mass, batch_t_zeros)
        
        return mse_equation
    
    return func


class PINN(nn.Module):
    def __init__(self, model, data_loss=None, physics_loss=None):
        super().__init__()
        
        self.model = model.to(device)
        
        if isinstance(data_loss, list):
            self.data_loss = data_loss
        else:
            self.data_loss = [data_loss]
        if isinstance(physics_loss, list):
            self.physics_loss = physics_loss
        else:
            self.physics_loss = [physics_loss]
        
        for name, param in self.model.named_parameters():
            if name.endswith('linear.weight'):
                nn.init.xavier_normal_(param)
            elif name.endswith('linear.bias'):
                nn.init.zeros_(param)

    def forward(self, x):
        return self.model(x)
    
    def loss(self, x, y):
        out = self.forward(x)

        loss = 0.
        for data_loss in self.data_loss:
            loss += data_loss(out, y)
        for physics_loss in self.physics_loss:
            loss += physics_loss(self.model, x)

        return loss
    
backbone = Backbone()
pinn = PINN(backbone, torch.nn.MSELoss(), navier_stokes_loss(Re=Re)) # TODO: weights, masks, 

In [15]:
N_EPOCHS = 500

optimizer = torch.optim.Adam(pinn.parameters(), lr=0.0001)

for epoch in range(N_EPOCHS):
    running_loss = 0.
    for iter, (x, y) in enumerate(train_dataloader):
        optimizer.zero_grad()
        # pred = pinn(x)
        loss = pinn.loss(x.to(device), y.to(device))
        loss.backward()
        # update weights
        optimizer.step()

        running_loss += loss.cpu().detach().numpy()
        print(f'Epoch {epoch+1}/{N_EPOCHS} -- loss: {running_loss/(iter+1):.5f}', end='\r')
    print('')

Epoch 1/500 -- loss: 285.65147
Epoch 2/500 -- loss: 285.52342
Epoch 3/500 -- loss: 285.32144
Epoch 4/500 -- loss: 285.23067
Epoch 5/500 -- loss: 284.98129
Epoch 6/500 -- loss: 284.77192
Epoch 7/500 -- loss: 284.52443
Epoch 8/500 -- loss: 284.42429
Epoch 9/500 -- loss: 284.20186
Epoch 10/500 -- loss: 283.92728
Epoch 11/500 -- loss: 283.78171
Epoch 12/500 -- loss: 283.53149
Epoch 13/500 -- loss: 283.37082
Epoch 14/500 -- loss: 283.10759
Epoch 15/500 -- loss: 282.92776
Epoch 16/500 -- loss: 282.78331
Epoch 17/500 -- loss: 282.59965
Epoch 18/500 -- loss: 282.36049
Epoch 19/500 -- loss: 282.23138
Epoch 20/500 -- loss: 281.97725
Epoch 21/500 -- loss: 281.76381
Epoch 22/500 -- loss: 281.56669
Epoch 23/500 -- loss: 281.39362
Epoch 24/500 -- loss: 281.03234
Epoch 25/500 -- loss: 280.66627
Epoch 26/500 -- loss: 280.41713
Epoch 27/500 -- loss: 280.18752
Epoch 28/500 -- loss: 280.03346
Epoch 29/500 -- loss: 279.72580
Epoch 30/500 -- loss: 279.60189
Epoch 31/500 -- loss: 279.39382
Epoch 32/500 -- l

KeyboardInterrupt: 

In [16]:
predictions = []
for t in range(T):
    tt = t*torch.ones_like(xx)
    X = torch.stack([tt, xx, yy], 0)
    X = torch.flatten(X, start_dim=1).T
    pred = pinn(X.to(device))
    u = pred[...,0].view(H,W)
    v = pred[...,1].view(H,W)
    p = pred[...,2].view(H,W)
    predictions.append(u.cpu().detach().numpy())
    predictions.append(v.cpu().detach().numpy())
    predictions.append(p.cpu().detach().numpy())
predictions = np.transpose(predictions, [1,2,0])

u = predictions[:,:,::3]
v = predictions[:,:,1::3]
p = predictions[:,:,2::3]
animate(u)